Source code for limits.storage

"""

"""
from abc import abstractmethod, ABCMeta
import inspect

from six.moves import urllib


try:
    from collections import Counter
except ImportError:  # pragma: no cover
    from .backports.counter import Counter  # pragma: no cover

import threading
import time

import six

from .errors import ConfigurationError
from .util import get_dependency

SCHEMES = {}


[docs]def storage_from_string(storage_string, **options): """ factory function to get an instance of the storage class based on the uri of the storage :param storage_string: a string of the form method://host:port :return: an instance of :class:`flask_limiter.storage.Storage` """ scheme = urllib.parse.urlparse(storage_string).scheme if not scheme in SCHEMES: raise ConfigurationError("unknown storage scheme : %s" % storage_string) return SCHEMES[scheme](storage_string, **options)
class StorageRegistry(type): def __new__(mcs, name, bases, dct): storage_scheme = dct.get('STORAGE_SCHEME', None) if not bases == (object,) and not storage_scheme: raise ConfigurationError( "%s is not configured correctly, it must specify a STORAGE_SCHEME class attribute" % name) cls = super(StorageRegistry, mcs).__new__(mcs, name, bases, dct) SCHEMES[storage_scheme] = cls return cls @six.add_metaclass(StorageRegistry) @six.add_metaclass(ABCMeta)
[docs]class Storage(object): """ Base class to extend when implementing a storage backend. """ def __init__(self, uri=None, **options): self.lock = threading.RLock() @abstractmethod
[docs] def incr(self, key, expiry, elastic_expiry=False): """ increments the counter for a given rate limit key :param str key: the key to increment :param int expiry: amount in seconds for the key to expire in :param bool elastic_expiry: whether to keep extending the rate limit window every hit. """ raise NotImplementedError
@abstractmethod
[docs] def get(self, key): """ :param str key: the key to get the counter value for """ raise NotImplementedError
@abstractmethod
[docs] def get_expiry(self, key): """ :param str key: the key to get the expiry for """ raise NotImplementedError
@abstractmethod
[docs] def check(self): """ check if storage is healthy """ raise NotImplementedError
@abstractmethod
[docs] def reset(self): """ reset storage to clear limits """ raise NotImplementedError
class LockableEntry(threading._RLock): __slots__ = ["atime", "expiry"] def __init__(self, expiry): self.atime = time.time() self.expiry = self.atime + expiry super(LockableEntry, self).__init__()
[docs]class MemoryStorage(Storage): """ rate limit storage using :class:`collections.Counter` as an in memory storage for fixed and elastic window strategies, and a simple list to implement moving window strategy. """ STORAGE_SCHEME = "memory" def __init__(self, uri=None, **_): self.storage = Counter() self.expirations = {} self.events = {} self.timer = threading.Timer(0.01, self.__expire_events) self.timer.start() super(MemoryStorage, self).__init__(uri) def __expire_events(self): for key in self.events.keys(): for event in list(self.events[key]): with event: if event.expiry <= time.time() and event in self.events[key]: self.events[key].remove(event) for key in list(self.expirations.keys()): if self.expirations[key] <= time.time(): self.storage.pop(key, None) self.expirations.pop(key, None) def __schedule_expiry(self): if not self.timer.is_alive(): self.timer = threading.Timer(0.01, self.__expire_events) self.timer.start()
[docs] def incr(self, key, expiry, elastic_expiry=False): """ increments the counter for a given rate limit key :param str key: the key to increment :param int expiry: amount in seconds for the key to expire in :param bool elastic_expiry: whether to keep extending the rate limit window every hit. """ self.get(key) self.__schedule_expiry() self.storage[key] += 1 if elastic_expiry or self.storage[key] == 1: self.expirations[key] = time.time() + expiry return self.storage.get(key, 0)
[docs] def get(self, key): """ :param str key: the key to get the counter value for """ if self.expirations.get(key, 0) <= time.time(): self.storage.pop(key, None) self.expirations.pop(key, None) return self.storage.get(key, 0)
[docs] def acquire_entry(self, key, limit, expiry, no_add=False): """ :param str key: rate limit key to acquire an entry in :param int limit: amount of entries allowed :param int expiry: expiry of the entry :param bool no_add: if False an entry is not actually acquired but instead serves as a 'check' :return: True/False """ self.events.setdefault(key, []) self.__schedule_expiry() timestamp = time.time() try: entry = self.events[key][limit - 1] except IndexError: entry = None if entry and entry.atime >= timestamp - expiry: return False else: if not no_add: self.events[key].insert(0, LockableEntry(expiry)) return True
[docs] def get_expiry(self, key): """ :param str key: the key to get the expiry for """ return int(self.expirations.get(key, -1))
[docs] def get_num_acquired(self, key, expiry): """ returns the number of entries already acquired :param str key: rate limit key to acquire an entry in :param int expiry: expiry of the entry """ timestamp = time.time() return len( [k for k in self.events[key] if k.atime >= timestamp - expiry] ) if self.events.get(key) else 0
[docs] def get_moving_window(self, key, limit, expiry): """ returns the starting point and the number of entries in the moving window :param str key: rate limit key :param int expiry: expiry of entry """ timestamp = time.time() acquired = self.get_num_acquired(key, expiry) for item in self.events.get(key, []): if item.atime >= timestamp - expiry: return int(item.atime), acquired return int(timestamp), acquired
[docs] def check(self): """ check if storage is healthy """ return True
def reset(self): self.storage.clear() self.expirations.clear() self.events.clear()
class RedisInteractor(object): SCRIPT_MOVING_WINDOW = """ local items = redis.call('lrange', KEYS[1], 0, tonumber(ARGV[2])) local expiry = tonumber(ARGV[1]) local a = 0 local oldest = nil for idx=1,#items do if tonumber(items[idx]) >= expiry then a = a + 1 if oldest == nil then oldest = tonumber(items[idx]) end else break end end return {oldest, a} """ SCRIPT_ACQUIRE_MOVING_WINDOW = """ local entry = redis.call('lindex', KEYS[1], tonumber(ARGV[2]) - 1) local timestamp = tonumber(ARGV[1]) local expiry = tonumber(ARGV[3]) if entry and tonumber(entry) >= timestamp - expiry then return false end local limit = tonumber(ARGV[2]) local no_add = tonumber(ARGV[4]) if 0 == no_add then redis.call('lpush', KEYS[1], timestamp) redis.call('ltrim', KEYS[1], 0, limit - 1) redis.call('expire', KEYS[1], expiry) end return true """ SCRIPT_CLEAR_KEYS = """ local keys = redis.call('keys', KEYS[1]) local res = 0 for i=1,#keys,5000 do res = res + redis.call('del', unpack(keys, i, math.min(i+4999, #keys))) end return res """ SCRIPT_INCR_EXPIRE = """ local current current = redis.call("incr",KEYS[1]) if tonumber(current) == 1 then redis.call("expire",KEYS[1],ARGV[1]) end return current """ def incr(self, key, expiry, connection, elastic_expiry=False): """ increments the counter for a given rate limit key :param connection: Redis connection :param str key: the key to increment :param int expiry: amount in seconds for the key to expire in """ value = connection.incr(key) if elastic_expiry or value == 1: connection.expire(key, expiry) return value def get(self, key, connection): """ :param connection: Redis connection :param str key: the key to get the counter value for """ return int(connection.get(key) or 0) def get_moving_window(self, key, limit, expiry): """ returns the starting point and the number of entries in the moving window :param str key: rate limit key :param int expiry: expiry of entry """ timestamp = time.time() window = self.lua_moving_window( [key], [int(timestamp - expiry), limit] ) return window or (timestamp, 0) def acquire_entry( self, key, limit, expiry, connection, no_add=False ): """ :param str key: rate limit key to acquire an entry in :param int limit: amount of entries allowed :param int expiry: expiry of the entry :param bool no_add: if False an entry is not actually acquired but instead serves as a 'check' :param connection: Redis connection :return: True/False """ timestamp = time.time() acquired = self.lua_acquire_window( [key], [timestamp, limit, expiry, int(no_add)] ) return bool(acquired) def get_expiry(self, key, connection=None): """ :param str key: the key to get the expiry for :param connection: Redis connection """ return int((connection.ttl(key) or 0) + time.time()) def check(self, connection): """ :param connection: Redis connection check if storage is healthy """ try: return connection.ping() except: # noqa return False
[docs]class RedisStorage(RedisInteractor, Storage): """ rate limit storage with redis as backend """ STORAGE_SCHEME = "redis" def __init__(self, uri, **_): """ :param str uri: uri of the form 'redis://host:port or redis://host:port/db' :raise ConfigurationError: when the redis library is not available or if the redis host cannot be pinged. """ if not get_dependency("redis"): raise ConfigurationError("redis prerequisite not available") # pragma: no cover self.storage = get_dependency("redis").from_url(uri) self.initialize_storage(uri) super(RedisStorage, self).__init__() def initialize_storage(self, uri): if not self.storage.ping(): raise ConfigurationError("unable to connect to redis at %s" % uri) # pragma: no cover self.lua_moving_window = self.storage.register_script( self.SCRIPT_MOVING_WINDOW ) self.lua_acquire_window = self.storage.register_script( self.SCRIPT_ACQUIRE_MOVING_WINDOW ) self.lua_clear_keys = self.storage.register_script( self.SCRIPT_CLEAR_KEYS ) self.lua_incr_expire = self.storage.register_script( RedisStorage.SCRIPT_INCR_EXPIRE )
[docs] def incr(self, key, expiry, elastic_expiry=False): """ increments the counter for a given rate limit key :param str key: the key to increment :param int expiry: amount in seconds for the key to expire in """ if elastic_expiry: return super(RedisStorage, self).incr( key, expiry, self.storage, elastic_expiry ) else: return self.lua_incr_expire([key], [expiry])
[docs] def get(self, key): """ :param str key: the key to get the counter value for """ return super(RedisStorage, self).get(key, self.storage)
[docs] def acquire_entry(self, key, limit, expiry, no_add=False): """ :param str key: rate limit key to acquire an entry in :param int limit: amount of entries allowed :param int expiry: expiry of the entry :param bool no_add: if False an entry is not actually acquired but instead serves as a 'check' :return: True/False """ return super(RedisStorage, self).acquire_entry( key, limit, expiry, self.storage, no_add=no_add )
[docs] def get_expiry(self, key): """ :param str key: the key to get the expiry for """ return super(RedisStorage, self).get_expiry(key, self.storage)
[docs] def check(self): """ check if storage is healthy """ return super(RedisStorage, self).check(self.storage)
[docs] def reset(self): """WARNING, this operation was designed to be fast, but was not tested on a large production based system. Be careful with its usage as it could be slow on very large data sets. This function calls a Lua Script to delete keys prefixed with 'LIMITER' in block of 5000.""" cleared = self.lua_clear_keys(['LIMITER*']) return cleared
class RedisSSLStorage(RedisStorage): """ rate limit storage with redis as backend using SSL connection """ STORAGE_SCHEME = "rediss" def __init__(self, uri, **options): """ :param str uri: uri of the form 'rediss://host:port or rediss://host:port/db' :raise ConfigurationError: when the redis library is not available or if the redis host cannot be pinged. """ super(RedisSSLStorage, self).__init__(uri, **options) #noqa
[docs]class RedisSentinelStorage(RedisStorage): """ rate limit storage with redis sentinel as backend """ STORAGE_SCHEME = "redis+sentinel" def __init__(self, uri, **options): """ :param str uri: url of the form 'redis+sentinel://host:port,host:port/service_name' :raise ConfigurationError: when the redis library is not available or if the redis master host cannot be pinged. """ if not get_dependency("redis"): raise ConfigurationError("redis prerequisite not available") # pragma: no cover parsed = urllib.parse.urlparse(uri) sentinel_configuration = [] password = None if parsed.password: password = parsed.password for loc in parsed.netloc[parsed.netloc.find("@")+1:].split(","): host, port = loc.split(":") sentinel_configuration.append((host, int(port))) self.service_name = ( parsed.path.replace("/", "") if parsed.path else options.get("service_name", None) ) if self.service_name is None: raise ConfigurationError("'service_name' not provided") self.sentinel = get_dependency("redis.sentinel").Sentinel( sentinel_configuration, socket_timeout=options.get("socket_timeout", 0.2), password=password ) self.initialize_storage(uri) super(RedisStorage, self).__init__() @property def storage(self): return self.sentinel.master_for(self.service_name) @property def storage_slave(self): return self.sentinel.slave_for(self.service_name)
[docs] def get(self, key): """ :param str key: the key to get the counter value for """ return super(RedisStorage, self).get(key, self.storage_slave)
[docs] def get_expiry(self, key): """ :param str key: the key to get the expiry for """ return super(RedisStorage, self).get_expiry(key, self.storage_slave)
[docs] def check(self): """ check if storage is healthy """ return super(RedisStorage, self).check(self.storage_slave)
class RedisClusterStorage(RedisStorage): """ rate limit storage with redis cluster as backend """ STORAGE_SCHEME = "redis+cluster" def __init__(self, uri, **options): """ :param str uri: url of the form 'redis+cluster://host:port,host:port' :raise ConfigurationError: when the rediscluster library is not available or if the redis host cannot be pinged. """ if not get_dependency("rediscluster"): raise ConfigurationError( "redis-py-cluster prerequisite not available" ) # pragma: no cover parsed = urllib.parse.urlparse(uri) cluster_hosts = [] for loc in parsed.netloc.split(","): host, port = loc.split(":") cluster_hosts.append({"host": host, "port": int(port)}) self.storage = get_dependency("rediscluster").RedisCluster( startup_nodes=cluster_hosts, max_connections=options.get("max_connections", 1000) ) self.initialize_storage(uri) super(RedisStorage, self).__init__() def reset(self): """ Redis Clusters are sharded and deleting across shards can't be done atomically. Because of this, this reset loops over all keys that are prefixed with 'LIMITER' and calls delete on them, one at a time. WARNING, this operation was not tested with extremely large data sets. On a large production based system, care should be taken with its usage as it could be slow on very large data sets""" keys = self.storage.keys('LIMITER*') return sum([self.storage.delete(k.decode('utf-8')) for k in keys])
[docs]class MemcachedStorage(Storage): """ rate limit storage with memcached as backend """ MAX_CAS_RETRIES = 10 STORAGE_SCHEME = "memcached" def __init__(self, uri, **options): """ :param str uri: memcached location of the form 'memcached://host:port,host:port' :raise ConfigurationError: when pymemcached is not available """ parsed = urllib.parse.urlparse(uri) self.cluster = [] for loc in parsed.netloc.split(","): host, port = loc.split(":") self.cluster.append((host, int(port))) self.library = options.get('library', 'pymemcache.client') self.client_getter = options.get('client_getter', self.get_client) if not get_dependency(self.library): raise ConfigurationError("memcached prerequisite not available." " please install %s" % self.library) # pragma: no cover self.local_storage = threading.local() self.local_storage.storage = None
[docs] def get_client(self, module, hosts): """ returns a memcached client. :param module: the memcached module :param hosts: list of memcached hosts :return: """ return module.Client(*hosts)
def call_memcached_func(self, func, *args, **kwargs): if 'noreply' in kwargs: if 'noreply' not in inspect.getargspec(func).args: kwargs.pop('noreply') # noqa return func(*args, **kwargs) @property def storage(self): """ lazily creates a memcached client instance using a thread local """ if not (hasattr(self.local_storage, "storage") and self.local_storage.storage): self.local_storage.storage = self.client_getter(get_dependency(self.library), self.cluster) return self.local_storage.storage
[docs] def get(self, key): """ :param str key: the key to get the counter value for """ return int(self.storage.get(key) or 0)
[docs] def incr(self, key, expiry, elastic_expiry=False): """ increments the counter for a given rate limit key :param str key: the key to increment :param int expiry: amount in seconds for the key to expire in :param bool elastic_expiry: whether to keep extending the rate limit window every hit. """ if not self.call_memcached_func(self.storage.add, key, 1, expiry, noreply=False): if elastic_expiry: value, cas = self.storage.gets(key) retry = 0 while ( not self.call_memcached_func(self.storage.cas, key, int(value or 0)+1, cas, expiry) and retry < self.MAX_CAS_RETRIES ): value, cas = self.storage.gets(key) retry += 1 self.call_memcached_func(self.storage.set, key + "/expires", expiry + time.time(), expire=expiry, noreply=False) return int(value or 0) + 1 else: return self.storage.incr(key, 1) self.call_memcached_func(self.storage.set, key + "/expires", expiry + time.time(), expire=expiry, noreply=False) return 1
[docs] def get_expiry(self, key): """ :param str key: the key to get the expiry for """ return int(float(self.storage.get(key + "/expires") or time.time()))
[docs] def check(self): """ check if storage is healthy """ try: self.call_memcached_func(self.storage.stats) return True except: # noqa return False