Over the years, I've written several different rate limiting methods using Redis for both commercial and personal projects. This two-part tutorial intends to cover two different but related methods of performing rate limiting in Redis using standard Redis commands and Lua scripting. Each method expands the number of use-cases for rate limiting, and cleans up some of the rougher edges of previous rate limiters.
This post assumes some experience with Python and Redis, and to a lesser extent Lua, but new users still reading docs should be okay.
Most uses of rate limiting on the web today are intended to limit the effect that someone can have on a given platform. Whether it is API limits at Twitter, posting limits at Reddit, or posting limits at StackOverflow, some limits cap resource utilization, and others limit the effect a spammer account can have. Whatever the reason, we need to count actions as they happen, and we need to prevent an action from happening if the user has reached or gone over their limit. Let's start with the plan of building a rate limiter for an API where we need to restrict each user to 240 requests per hour.
We know that we need to count and limit a user, so let's get some utility code out of the way. First, we need to have a function that gives us one or more identifiers for the user performing an action. Sometimes that is just a user id, other times it's the remote IP address; I usually use both when available, and at least IP address if the user hasn't logged in yet. Below is a function that gets the IP address and user id (when available) using Flask with the Flask-Login plugin.
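    from flask import g, request

    def get_identifiers():
        # Always key on the remote IP; add the user id once the user is logged in.
        ret = ['ip:' + request.remote_addr]
        # Flask-Login 0.3+ exposes is_authenticated as a property; drop the
        # parentheses on those versions.
        if g.user.is_authenticated():
            ret.append('user:%s' % g.user.get_id())
        return ret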
Now that we have a function that returns a list of identifiers for an action, let's start counting and limiting. One of the simplest rate limiting methods available in Redis starts by taking the times of the actions as they happen, and buckets actions into ranges of times, counting them as they occur. If the number of actions in a bucket exceeds the limit, we don't allow the action. Below is a function that performs the rate limiting using an automatically-expiring counter that uses 1 hour buckets.
    import time

    def over_limit(conn, duration=3600, limit=240):
        # One bucket per `duration`-second window, e.g. ':3600:472222'.
        bucket = ':%i:%i' % (duration, time.time() // duration)
        for id in get_identifiers():
            key = id + bucket
            count = conn.incr(key)
            # Let Redis discard the counter once its window has passed.
            conn.expire(key, duration)
            if count > limit:
                return True
        return False
This function shouldn't be too hard to understand; for each identifier we increment the appropriate key in Redis, set the key to expire in an hour, and if the count is more than the limit, we return True, signifying that we are over the limit. Otherwise we return False.
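To make the bucketing concrete, here is a minimal sketch of just the key computation, pulled out into a helper. The name `bucket_key`, the explicit `now` argument, and the sample identifiers are assumptions made for illustration; the function above builds the same string inline from `time.time()`.

```python
def bucket_key(identifier, duration, now):
    # Integer-divide the timestamp by the window size so every request in
    # the same window lands on the same counter key.
    return '%s:%i:%i' % (identifier, duration, now // duration)


# Two requests 0.1 seconds apart can fall into different hourly buckets
# when they straddle the top of the hour.
before_boundary = bucket_key('user:42', 3600, 7199.9)
after_boundary = bucket_key('user:42', 3600, 7200.0)

# With a 1-second window, the bucket number is just the whole second.
per_second = bucket_key('ip:203.0.113.7', 1, 7200.4)

print(before_boundary, after_boundary, per_second)
```

Because the bucket number is part of the key, a new window always starts from a fresh counter, and the old counter simply expires.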
And that's it. Well, sort of. This gets us past our initial goal of having a basic rate limiter to limit each user to 240 requests per hour. But reality has a tendency to catch us when we aren't looking, and clients using the API have noticed that their limit is reset at the top of every hour. Now users have started making all 240 requests in the first few seconds they can, so all of our work limiting requests is wasted, right?
Our initial rate limiting on a per-hour basis was successful in that it limited users on an hourly basis, but users started using all of their API requests as soon as they could (at the beginning of the hour). Looking at the problem, it seems almost obvious that in addition to a per-hour rate limit, we should probably also have a per-second and/or per-minute rate limit to smooth out peak request rates.
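The burst behavior is easy to reproduce without a Redis server. The sketch below uses a hypothetical in-memory stand-in (`FakeCounterStore`) that only mimics `INCR` and ignores expiration, plus a variant of the hourly check that takes the identifier and the current time as arguments so the clock can be controlled.

```python
class FakeCounterStore:
    """Hypothetical in-memory stand-in for the INCR/EXPIRE calls we use."""

    def __init__(self):
        self.counts = {}

    def incr(self, key):
        self.counts[key] = self.counts.get(key, 0) + 1
        return self.counts[key]

    def expire(self, key, seconds):
        pass  # expiration is not simulated in this sketch


def over_limit_at(conn, identifier, now, duration=3600, limit=240):
    # Same logic as the hourly limiter, with the clock passed in explicitly.
    key = '%s:%i:%i' % (identifier, duration, now // duration)
    count = conn.incr(key)
    conn.expire(key, duration)
    return count > limit


store = FakeCounterStore()

# 300 requests arrive within the first second of an hour.
burst = [over_limit_at(store, 'user:42', now=0.5) for _ in range(300)]
allowed_in_burst = burst.count(False)

# The first request of the next hour starts from a fresh counter.
rejected_next_hour = over_limit_at(store, 'user:42', now=3600.0)

print(allowed_in_burst, rejected_next_hour)
```

All 240 allowed requests land in the same second, which is exactly the peak the hourly limit alone cannot smooth out.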
Let's say that we determined that 10 requests per second, 120 requests per minute, and 240 requests per hour were fair enough to our users, and let us better distribute requests over time. We could simply re-use our earlier over_limit() function to offer this functionality.
    def over_limit_multi(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
        # Check the smallest window first; stop at the first limit that is exceeded.
        for duration, limit in limits:
            if over_limit(conn, duration, limit):
                return True
        return False
This will work for our intended use, but each of our 3 rate limit calls can result in two counter updates and two expire calls (one for each of the IP and user keys), so we may need to perform 12 total round trips to Redis just to say whether someone is over their limit. One common method of minimizing the number of round trips to Redis is to use what is called 'pipelining'. Pipelining in the Redis context sends multiple commands to Redis in a single round trip, which can reduce overall latency.
Our over_limit() function is written in such a way that we could easily replace our INCR and EXPIRE calls with a single pipelined request to increment the count and update the key expiration. The updated function can be seen below, and cuts our number of round trips from 12 to 6 when combined with over_limit_multi().
    def over_limit(conn, duration=3600, limit=240):
        pipe = conn.pipeline(transaction=True)
        bucket = ':%i:%i' % (duration, time.time() // duration)
        for id in get_identifiers():
            key = id + bucket
            pipe.incr(key)
            pipe.expire(key, duration)
            # execute() returns one result per queued command; the first is the INCR count.
            if pipe.execute()[0] > limit:
                return True
        return False
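The 12-versus-6 arithmetic can be checked with a small counting sketch. Everything here is a hypothetical stand-in rather than the redis-py client: `CountingConnection` charges one round trip per direct command, and `CountingPipeline` charges one round trip per `execute()` no matter how many commands are queued. The two identifiers and the three limits are fixed sample values.

```python
class CountingPipeline:
    """Buffers commands and sends them in one simulated round trip."""

    def __init__(self, conn):
        self.conn = conn
        self.commands = []

    def incr(self, key):
        self.commands.append(('incr', key))

    def expire(self, key, seconds):
        self.commands.append(('expire', key))

    def execute(self):
        self.conn.round_trips += 1  # the whole batch costs one round trip
        results = [self.conn.apply(name, key) for name, key in self.commands]
        self.commands = []
        return results


class CountingConnection:
    """Hypothetical in-memory stand-in for a Redis client."""

    def __init__(self):
        self.round_trips = 0
        self.counts = {}

    def apply(self, name, key):
        if name == 'incr':
            self.counts[key] = self.counts.get(key, 0) + 1
            return self.counts[key]
        return True  # EXPIRE is acknowledged but not simulated

    def incr(self, key):
        self.round_trips += 1
        return self.apply('incr', key)

    def expire(self, key, seconds):
        self.round_trips += 1
        return self.apply('expire', key)

    def pipeline(self):
        return CountingPipeline(self)


IDENTIFIERS = ['ip:203.0.113.7', 'user:42']
LIMITS = [(1, 10), (60, 120), (3600, 240)]


def check_unpipelined(conn, now):
    for duration, limit in LIMITS:
        for ident in IDENTIFIERS:
            key = '%s:%i:%i' % (ident, duration, now // duration)
            count = conn.incr(key)
            conn.expire(key, duration)
            if count > limit:
                return True
    return False


def check_pipelined(conn, now):
    for duration, limit in LIMITS:
        for ident in IDENTIFIERS:
            key = '%s:%i:%i' % (ident, duration, now // duration)
            pipe = conn.pipeline()
            pipe.incr(key)
            pipe.expire(key, duration)
            if pipe.execute()[0] > limit:
                return True
    return False


plain = CountingConnection()
check_unpipelined(plain, now=1000.0)

piped = CountingConnection()
check_pipelined(piped, now=1000.0)

print(plain.round_trips, piped.round_trips)
```

A request that is under every limit touches 3 windows times 2 identifiers, so it pays for 6 key updates either way; pipelining only changes how many network round trips those updates cost.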
Halving the number of round trips to Redis is great, but we are still performing 6 round trips just to say whether a user can make an API call. We could write a replacement over_limit_multi() that makes all increment and expire operations at once, checking the limits after, but the obvious implementation actually has a counting bug that can prevent users from being able to make 240 successful requests in an hour (in the worst case, a client may experience 10 successful requests in an hour, despite making over 100 requests per second for the entire hour). This counting bug can be fixed with a second round trip to Redis, but let's instead shift our logic into Redis.
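To see how such a counting bug can arise, here is a small in-memory simulation. It assumes the "obvious" implementation increments every counter before checking any limit, so requests rejected by the per-second limit still count against the per-minute and per-hour windows. The `counters` dict, the `user:42` identifier, the request rate of 250 per second, and both function names are hypothetical stand-ins, not Redis calls.

```python
from collections import defaultdict

LIMITS = [(1, 10), (60, 120), (3600, 240)]


def bucket_keys(identifier, now):
    # One (key, limit) pair per window, smallest window first.
    return [('%s:%i:%i' % (identifier, duration, now // duration), limit)
            for duration, limit in LIMITS]


def increment_all_then_check(counters, identifier, now):
    """Bump every counter first, then decide (the naive pipelined version)."""
    over = False
    for key, limit in bucket_keys(identifier, now):
        counters[key] += 1
        if counters[key] > limit:
            over = True
    return over


def stop_at_first_exceeded(counters, identifier, now):
    """Stop as soon as one limit is exceeded, leaving larger windows untouched."""
    for key, limit in bucket_keys(identifier, now):
        counters[key] += 1
        if counters[key] > limit:
            return True
    return False


def count_allowed(check, seconds=5, per_second=250):
    # Simulate a client hammering the API for a few seconds.
    counters = defaultdict(int)
    allowed = 0
    for second in range(seconds):
        for _ in range(per_second):
            if not check(counters, 'user:42', float(second)):
                allowed += 1
    return allowed


naive_allowed = count_allowed(increment_all_then_check)
early_exit_allowed = count_allowed(stop_at_first_exceeded)
print(naive_allowed, early_exit_allowed)
```

Under these assumptions the naive version burns through the whole hourly allowance inside the first second while letting only 10 requests succeed, whereas stopping at the first exceeded limit keeps admitting 10 requests every second.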
Instead of trying to fix a fully pipelined version, we can use the ability to execute Lua scripts inside Redis to perform the same operation while also keeping to one round trip. The specific operations we are going to perform in Lua are almost the exact same operations as we were originally performing in Python. We are going to iterate over the limits themselves, and for each identifier, we are going to increment a counter, update the expiration time of the updated counter, then check to see if we are over the limit. We will also use a small Python wrapper around our Lua to handle argument conversion and to hide the details of script loading.
    import json

    def over_limit_multi_lua(conn, limits=[(1, 10), (60, 125), (3600, 250)]):
        # Register the script once per connection object and reuse it afterwards.
        if not hasattr(conn, 'over_limit_multi_lua'):
            conn.over_limit_multi_lua = conn.register_script(over_limit_multi_lua_)

        return conn.over_limit_multi_lua(
            keys=get_identifiers(), args=[json.dumps(limits), time.time()])

    over_limit_multi_lua_ = '''
    local limits = cjson.decode(ARGV[1])
    local now = tonumber(ARGV[2])
    for i, limit in ipairs(limits) do
        -- each limit is a {duration, max_count} pair
        local duration = limit[1]

        local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
        for j, id in ipairs(KEYS) do
            local key = id .. bucket

            local count = redis.call('INCR', key)
            redis.call('EXPIRE', key, duration)
            if tonumber(count) > limit[2] then
                return 1
            end
        end
    end
    return 0
    '''
With the section of code starting with 'local bucket', you will notice that our Lua looks very much like and performs the same operations as our original over_limit() function, with the remaining code handling argument unpacking and iterating over the individual limits.
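The argument conversion is worth a quick look, since script arguments always arrive in Lua as strings. The sketch below only shows the Python side: what the JSON payload for the limits looks like before it is sent. The variable names are illustrative.

```python
import json

limits = [(1, 10), (60, 125), (3600, 250)]

# Tuples are serialized as JSON arrays, so the script receives a list of
# two-element arrays that cjson.decode turns into Lua tables.
payload = json.dumps(limits)
print(payload)

# Decoding on the Python side shows the same nested structure the script sees.
decoded = json.loads(payload)
first_duration, first_limit = decoded[0]
```

Lua tables are indexed from 1, so inside the script the first element of each pair is the duration and the second is the maximum count. The timestamp is likewise sent as a string, which is why the script passes it through tonumber() before dividing.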
At this point we have built a rate limiting method that handles multiple levels of timing granularity, can handle multiple identifiers for a single user, and can be performed in a single round trip between the client and Redis. We went from a single-bucket rate limiter to one that can evaluate multiple limits simultaneously.
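As one way to wire a limiter into request handling, here is a hedged sketch of a decorator that rejects a call with HTTP status 429 (Too Many Requests) when a check function reports that the caller is over the limit. The names `rate_limited` and `api_status` are hypothetical; in an application the check would be something like `lambda: over_limit_multi_lua(conn)`, and a Flask view can return the same `(body, status)` tuple.

```python
import functools


def rate_limited(is_over_limit):
    """Wrap a handler so it is skipped when is_over_limit() returns true."""
    def decorator(handler):
        @functools.wraps(handler)
        def wrapper(*args, **kwargs):
            if is_over_limit():
                # 429 is the standard "Too Many Requests" status code.
                return ('Too Many Requests', 429)
            return handler(*args, **kwargs)
        return wrapper
    return decorator


# Stand-in check that flips to "over limit" after three calls.
calls = {'n': 0}

def fake_check():
    calls['n'] += 1
    return calls['n'] > 3


@rate_limited(fake_check)
def api_status():
    return ('ok', 200)


responses = [api_status() for _ in range(5)]
print(responses)
```

Keeping the check as a plain callable means any of the limiter functions can be swapped in without touching the handlers.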
Any of the rate limiting functions discussed in this post are usable for many different applications. In part two, I'll cover a different way of approaching rate limiting, which rounds out the remaining rough edges in our rate limiter.
Please note: due to the way Binpress handles commenting, I can no longer comment on this post; please post any replies on my blog instead.