Skip to content

Instantly share code, notes, and snippets.

Last active April 17, 2022 09:22
Show Gist options
  • Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
Regular and sliding-window rate limiting, to accompany two blog posts.
Copyright 2014, Josiah Carlson - [email protected]
Released under the MIT license
This module intends to show how to perform standard and sliding-window rate
limits as a companion to the two articles posted on Binpress entitled
"Introduction to rate limiting with Redis", parts 1 and 2:
... which will be (or have already been) reposted on my personal blog at least 2
weeks after their original posting:
import json
import time
from flask import g, request
def get_identifiers():
    """Return the rate-limit key prefixes for the current Flask request.

    The list always starts with ``'ip:' + request.remote_addr``. When
    ``g.user`` reports that it is authenticated, ``'user:' + g.user.get_id()``
    is appended as a second identifier.

    NOTE(review): ``is_authenticated`` is called as a method here; confirm
    this matches the user object the application stores on ``g.user``
    (some user libraries expose it as a property instead).
    """
    identifiers = ['ip:' + request.remote_addr]
    if not g.user.is_authenticated():
        # Anonymous requests are limited by IP address only.
        return identifiers
    return identifiers + ['user:' + g.user.get_id()]
def over_limit(conn, duration=3600, limit=240):
    """Fixed-window rate limit check, one Redis round trip per command.

    For each identifier of the current request, INCR the counter for the
    current ``duration``-second window and refresh its TTL to ``duration``.
    Returns True as soon as one identifier's count exceeds ``limit``
    (identifiers after it are not incremented); otherwise returns False.

    A later definition of over_limit() in this module replaces this one.
    """
    window = time.time() // duration
    suffix = ':%i:%i' % (duration, window)
    for ident in get_identifiers():
        counter_key = ident + suffix
        hits = conn.incr(counter_key)
        conn.expire(counter_key, duration)
        if hits > limit:
            return True
    return False
def over_limit_multi(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Return True if any of several fixed-window limits is exceeded.

    ``limits`` is a sequence of ``(duration, limit)`` pairs checked in order
    with over_limit(). Checking stops at the first pair that reports
    over-limit, so later windows are not incremented for this request.
    """
    return any(over_limit(conn, window, cap) for window, cap in limits)
def over_limit(conn, duration=3600, limit=240):
    """Fixed-window rate limit check using one pipelined round trip per identifier.

    Replaces the earlier over_limit() function and reduces round trips with
    pipelining. For each identifier of the current request, the INCR of the
    counter for the current ``duration``-second window and the refresh of its
    TTL are sent together in a single transactional pipeline.

    Args:
        conn: Redis client providing ``pipeline()``.
        duration: window length in seconds (also used as the key TTL).
        limit: maximum number of hits allowed per identifier per window.

    Returns:
        True as soon as one identifier's count exceeds ``limit`` (identifiers
        after it are not incremented); otherwise False.
    """
    pipe = conn.pipeline(transaction=True)
    bucket = ':%i:%i' % (duration, time.time() // duration)
    for ident in get_identifiers():
        key = ident + bucket
        # INCR must be queued before EXPIRE: execute() returns results in
        # queue order, so [0] is the new counter value. Without the INCR,
        # [0] would be the EXPIRE result and the limit could never trigger.
        pipe.incr(key)
        pipe.expire(key, duration)
        if pipe.execute()[0] > limit:
            return True
    return False
def over_limit_multi_lua(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Fixed-window checks for several limits in one server-side Lua call.

    The script object from ``conn.register_script`` is cached on ``conn`` as
    ``over_limit_lua`` the first time this is called. The request's
    identifiers are passed as KEYS; ``limits`` (JSON-encoded) and the local
    ``time.time()`` are passed as ARGV.

    Returns the script's result: 1 if any limit is exceeded, else 0.
    """
    if not hasattr(conn, 'over_limit_lua'):
        conn.over_limit_lua = conn.register_script(over_limit_multi_lua_)
    script = conn.over_limit_lua
    keys = get_identifiers()
    argv = [json.dumps(limits), time.time()]
    return script(keys=keys, args=argv)
# Lua script used by over_limit_multi_lua().
#   KEYS    - request identifiers (e.g. 'ip:...', 'user:...').
#   ARGV[1] - JSON list of [duration, limit] pairs.
#   ARGV[2] - caller's current Unix time in seconds.
# For every pair and every identifier it INCRs the fixed-window counter
# '<id>:<duration>:<window>' and refreshes its TTL to <duration>. It returns 1
# as soon as a counter exceeds its limit (later counters are not incremented);
# otherwise it returns 0.
over_limit_multi_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
for i, limit in ipairs(limits) do
    local duration = limit[1]

    local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
    for j, id in ipairs(KEYS) do
        local key = id .. bucket

        local count = redis.call('INCR', key)
        redis.call('EXPIRE', key, duration)
        if tonumber(count) > limit[2] then
            return 1
        end
    end
end
return 0
'''
def over_limit_sliding_window(conn, weight=1,
                              limits=[(1, 10), (60, 120), (3600, 240, 60)],
                              redis_time=False):
    """Sliding-window rate limit check evaluated by a server-side Lua script.

    Args:
        conn: Redis client; the script returned by ``register_script`` is
            cached on it as ``over_limit_sliding_window_lua``.
        weight: how much this request counts against every limit.
        limits: sequence of ``(duration, limit)`` or
            ``(duration, limit, precision)`` tuples. Inside the script,
            precision defaults to, and is capped at, duration.
        redis_time: when true, take the current time from the Redis
            server (``conn.time()[0]``, whole seconds) instead of the local
            ``time.time()``.

    Returns:
        The script's result: 1 if the request would exceed a limit (counts
        are not incremented), 0 once the request has been counted.
    """
    if not hasattr(conn, 'over_limit_sliding_window_lua'):
        conn.over_limit_sliding_window_lua = conn.register_script(
            over_limit_sliding_window_lua_)
    if redis_time:
        now = conn.time()[0]
    else:
        now = time.time()
    script = conn.over_limit_sliding_window_lua
    keys = get_identifiers()
    argv = [json.dumps(limits), now, weight]
    return script(keys=keys, args=argv)
# Lua script used by over_limit_sliding_window().
#   KEYS    - request identifiers; each is a HASH holding all counters.
#   ARGV[1] - JSON list of [duration, limit] or [duration, limit, precision].
#   ARGV[2] - current Unix time in seconds.
#   ARGV[3] - weight of this request (defaults to 1).
#
# Each limit splits its window into ceil(duration / precision) blocks. Per
# key the HASH stores:
#   '<duration>:<precision>:'          - running total over live blocks
#   '<duration>:<precision>:<block>'   - count for one block
#   '<duration>:<precision>:o'         - oldest live block id
# Blocks that have fallen out of the window are subtracted and deleted before
# each check.
#
# Returns 1 if the request would exceed any limit, or if the stored oldest
# block id is ahead of the current block. The latter is a guard against the
# clock moving backwards. In both cases nothing is incremented.
# Otherwise the script adds weight to every counter, EXPIREs each HASH after
# the longest duration, and returns 0.
over_limit_sliding_window_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
local weight = tonumber(ARGV[3] or '1')
local longest_duration = 0
local saved_keys = {}
-- handle cleanup and limit checks
for i, limit in ipairs(limits) do

    local duration = limit[1]
    longest_duration = math.max(longest_duration, duration)
    local precision = limit[3] or duration
    precision = math.min(precision, duration)
    local blocks = math.ceil(duration / precision)
    local saved = {}
    table.insert(saved_keys, saved)
    saved.block_id = math.floor(now / precision)
    saved.trim_before = saved.block_id - blocks + 1
    saved.count_key = duration .. ':' .. precision .. ':'
    saved.ts_key = saved.count_key .. 'o'
    for j, key in ipairs(KEYS) do

        local old_ts = redis.call('HGET', key, saved.ts_key)
        old_ts = old_ts and tonumber(old_ts) or saved.trim_before
        -- old_ts is a block id (it is written from saved.trim_before), so it
        -- must be compared with the current block id, not the raw timestamp
        if old_ts > saved.block_id then
            -- don't write in the past
            return 1
        end

        -- discover what needs to be cleaned up; live blocks all lie in
        -- [old_ts, old_ts + blocks - 1], so scan at most that many
        local decr = 0
        local dele = {}
        local trim = math.min(saved.trim_before, old_ts + blocks)
        for old_block = old_ts, trim - 1 do
            local bkey = saved.count_key .. old_block
            local bcount = redis.call('HGET', key, bkey)
            if bcount then
                decr = decr + tonumber(bcount)
                table.insert(dele, bkey)
            end
        end

        -- handle cleanup
        local cur
        if #dele > 0 then
            redis.call('HDEL', key, unpack(dele))
            cur = redis.call('HINCRBY', key, saved.count_key, -decr)
        else
            cur = redis.call('HGET', key, saved.count_key)
        end

        -- check our limits
        if tonumber(cur or '0') + weight > limit[2] then
            return 1
        end
    end
end

-- there are enough resources, update the counts
for i, limit in ipairs(limits) do
    local saved = saved_keys[i]

    for j, key in ipairs(KEYS) do
        -- update the current timestamp, count, and bucket count
        redis.call('HSET', key, saved.ts_key, saved.trim_before)
        redis.call('HINCRBY', key, saved.count_key, weight)
        redis.call('HINCRBY', key, saved.count_key .. saved.block_id, weight)
    end
end

-- We calculated the longest-duration limit so we can EXPIRE
-- the whole HASH for quick and easy idle-time cleanup :)
if longest_duration > 0 then
    for _, key in ipairs(KEYS) do
        redis.call('EXPIRE', key, longest_duration)
    end
end

return 0
'''
Copy link

ciokan commented Jan 21, 2017

How would you return an actual timestamp instead of 1 to be used in a Retry-After header?

Copy link

The answer for you @ciokan is that you need to modify the Lua script to calculate the delay. Right now it just returns whether you need to wait; the `return 1` after the "check our limits" test is the line you are looking for.

Copy link

Hi I have three questions.

Question 1

In over_limit_sliding_window_lua_, should

if old_ts > now then

(here) be

if old_ts > saved.block_id then

because old_ts is the oldest block id, not a timestamp?

Question 2


local trim = math.min(saved.trim_before, old_ts + blocks)

(here) be

saved.trim_before = math.min(saved.trim_before, old_ts + blocks)

because later, when saving the oldest block id, the code uses saved.trim_before: redis.call('HSET', key, saved.ts_key, saved.trim_before)


Question 3

Is the purpose of the code

local trim = math.min(saved.trim_before, old_ts + blocks)

(here) to limit the number of blocks to trim to at most `blocks`?

Copy link

apmcodes commented Nov 14, 2018


How would you return an actual timestamp instead of 1 to be used in a Retry-After header?

Replace line 157 (return 1) with the code below. It loops through the blocks of the present duration to find the earliest block in which a request was made, then calculates the time until that block becomes stale and thus allows a new request.

            -- return 1
            local last_attempt
            for last_block = saved.trim_before, saved.block_id, precision do
                local bcount = redis.call('HGET', key, saved.count_key .. last_block)
                if (bcount) then
                    last_attempt = last_block
            local next_attempt
            if last_attempt then
                next_attempt = (last_attempt + blocks) * precision
                next_attempt = 0
            return next_attempt

Note: The next_attempt received is UNIX timestamp in seconds and not milliseconds

@josiahcarlson Please review this code for any improvement or bug

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment