Skip to content

Commit e76765c

Browse files
committed
add retry-after header
1 parent cef463d commit e76765c

9 files changed

Lines changed: 209 additions & 149 deletions

File tree

aikido_zen/middleware/asgi.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ async def __call__(self, scope, receive, send):
1919
message = "You are rate limited by Zen."
2020
if result["trigger"] == "ip" and result["ip"]:
2121
message += " (Your IP: " + result["ip"] + ")"
22-
return await send_status_code_and_text(send, (message, 429))
22+
extra_headers = [
23+
(b"retry-after", str(result["retry_after_seconds"]).encode())
24+
]
25+
return await send_status_code_and_text(send, (message, 429), extra_headers)
2326

2427
if result["type"] == "blocked":
2528
return await send_status_code_and_text(
@@ -30,13 +33,16 @@ async def __call__(self, scope, receive, send):
3033
return await self.app(scope, receive, send)
3134

3235

33-
async def send_status_code_and_text(send, pre_response):
36+
async def send_status_code_and_text(send, pre_response, extra_headers=None):
3437
"""Sends a status code and text"""
38+
headers = [(b"content-type", b"text/plain")]
39+
if extra_headers:
40+
headers = headers + extra_headers
3541
await send(
3642
{
3743
"type": "http.response.start",
3844
"status": pre_response[1],
39-
"headers": [(b"content-type", b"text/plain")],
45+
"headers": headers,
4046
}
4147
)
4248
await send(

aikido_zen/middleware/django.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def __call__(self, request):
2727
message = "You are rate limited by Zen."
2828
if result["trigger"] == "ip" and result["ip"]:
2929
message += " (Your IP: " + result["ip"] + ")"
30-
return self.HttpResponse(message, content_type="text/plain", status=429)
30+
response = self.HttpResponse(message, content_type="text/plain", status=429)
31+
response["Retry-After"] = str(result["retry_after_seconds"])
32+
return response
3133

3234
if result["type"] == "blocked":
3335
return self.HttpResponse(

aikido_zen/middleware/flask.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __call__(self, environ, start_response):
2828
if result["trigger"] == "ip" and result["ip"]:
2929
message += " (Your IP: " + result["ip"] + ")"
3030
res = self.Response(message, mimetype="text/plain", status=429)
31+
res.headers["Retry-After"] = str(result["retry_after_seconds"])
3132
return res(environ, start_response)
3233

3334
if result["type"] == "blocked":

aikido_zen/middleware/init_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,14 @@ def test_cache_comms_with_endpoints():
152152

153153
mock_comms.send_data_to_bg_process.return_value = {
154154
"success": True,
155-
"data": {"block": True, "trigger": "my_trigger"},
155+
"data": {"block": True, "trigger": "my_trigger", "retry_after_seconds": 10},
156156
}
157157
assert thread_cache.stats.rate_limited_hits == 0
158158
assert should_block_request() == {
159159
"block": True,
160160
"ip": "1.1.1.1",
161161
"type": "ratelimited",
162162
"trigger": "my_trigger",
163+
"retry_after_seconds": 10,
163164
}
164165
assert thread_cache.stats.rate_limited_hits == 1

aikido_zen/middleware/should_block_request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def should_block_request():
6969
"type": "ratelimited",
7070
"trigger": ratelimit_res["data"]["trigger"],
7171
"ip": context.remote_address,
72+
"retry_after_seconds": ratelimit_res["data"]["retry_after_seconds"],
7273
}
7374
except Exception as e:
7475
logger.debug("Exception occurred in should_block_request: %s", e)

aikido_zen/ratelimiting/__init__.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,46 @@ def should_ratelimit_request(
2828
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])
2929

3030
if group:
31-
allowed = connection_manager.rate_limiter.is_allowed(
31+
result = connection_manager.rate_limiter.is_allowed(
3232
get_key_for_group(endpoint, group),
3333
windows_size_in_ms,
3434
max_requests,
3535
)
36-
if not allowed:
37-
return {"block": True, "trigger": "group"}
36+
if not result["allowed"]:
37+
return {
38+
"block": True,
39+
"trigger": "group",
40+
"retry_after_seconds": result["retry_after_seconds"],
41+
}
3842

3943
# Do not check IP or user rate limit if group is set
4044
return {"block": False}
4145
if user:
42-
allowed = connection_manager.rate_limiter.is_allowed(
46+
result = connection_manager.rate_limiter.is_allowed(
4347
get_key_for_user(endpoint, user),
4448
windows_size_in_ms,
4549
max_requests,
4650
)
47-
if not allowed:
48-
return {"block": True, "trigger": "user"}
51+
if not result["allowed"]:
52+
return {
53+
"block": True,
54+
"trigger": "user",
55+
"retry_after_seconds": result["retry_after_seconds"],
56+
}
4957
# Do not check IP rate limit if user is set
5058
return {"block": False}
5159
if remote_address:
52-
allowed = connection_manager.rate_limiter.is_allowed(
60+
result = connection_manager.rate_limiter.is_allowed(
5361
get_key_for_ip(endpoint, remote_address),
5462
windows_size_in_ms,
5563
max_requests,
5664
)
57-
if not allowed:
58-
return {"block": True, "trigger": "ip"}
65+
if not result["allowed"]:
66+
return {
67+
"block": True,
68+
"trigger": "ip",
69+
"retry_after_seconds": result["retry_after_seconds"],
70+
}
5971

6072
return {"block": False}
6173

aikido_zen/ratelimiting/init_test.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ def test_rate_limits_by_ip():
6060
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
6161
"block": False
6262
}
63-
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
64-
"block": True,
65-
"trigger": "ip",
66-
}
63+
result = should_ratelimit_request(route_metadata, "1.2.3.4", None, cm)
64+
assert result["block"] is True
65+
assert result["trigger"] == "ip"
66+
assert result["retry_after_seconds"] >= 0
6767

6868

6969
def test_rate_limiting_ip_allowed():
@@ -126,10 +126,10 @@ def test_rate_limiting_by_user(user):
126126
assert should_ratelimit_request(route_metadata, "1.2.3.6", user, cm) == {
127127
"block": False
128128
}
129-
assert should_ratelimit_request(route_metadata, "1.2.3.7", user, cm) == {
130-
"block": True,
131-
"trigger": "user",
132-
}
129+
result = should_ratelimit_request(route_metadata, "1.2.3.7", user, cm)
130+
assert result["block"] is True
131+
assert result["trigger"] == "user"
132+
assert result["retry_after_seconds"] >= 0
133133

134134

135135
def test_rate_limiting_with_wildcard():
@@ -160,9 +160,12 @@ def test_rate_limiting_with_wildcard():
160160
) == {"block": False}
161161

162162
# This request should trigger the rate limit
163-
assert should_ratelimit_request(
163+
result = should_ratelimit_request(
164164
create_route_metadata(route="/api/login"), "1.2.3.4", None, cm
165-
) == {"block": True, "trigger": "ip"}
165+
)
166+
assert result["block"] is True
167+
assert result["trigger"] == "ip"
168+
assert result["retry_after_seconds"] >= 0
166169

167170

168171
def test_rate_limiting_with_wildcard2():
@@ -193,10 +196,10 @@ def test_rate_limiting_with_wildcard2():
193196

194197
# This request should trigger the rate limit
195198
metadata = create_route_metadata(route="/api/login", method="GET")
196-
assert should_ratelimit_request(metadata, "1.2.3.4", None, cm) == {
197-
"block": True,
198-
"trigger": "ip",
199-
}
199+
result = should_ratelimit_request(metadata, "1.2.3.4", None, cm)
200+
assert result["block"] is True
201+
assert result["trigger"] == "ip"
202+
assert result["retry_after_seconds"] >= 0
200203

201204

202205
def test_rate_limiting_by_user_with_same_ip():
@@ -228,10 +231,10 @@ def test_rate_limiting_by_user_with_same_ip():
228231
}
229232

230233
# This request should trigger the rate limit
231-
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
232-
"block": True,
233-
"trigger": "user",
234-
}
234+
result = should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm)
235+
assert result["block"] is True
236+
assert result["trigger"] == "user"
237+
assert result["retry_after_seconds"] >= 0
235238

236239

237240
def test_rate_limiting_by_user_with_different_ips():
@@ -267,10 +270,10 @@ def test_rate_limiting_by_user_with_different_ips():
267270
}
268271

269272
# This request from second IP should trigger the rate limit
270-
assert should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm) == {
271-
"block": True,
272-
"trigger": "user",
273-
}
273+
result = should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm)
274+
assert result["block"] is True
275+
assert result["trigger"] == "user"
276+
assert result["retry_after_seconds"] >= 0
274277

275278

276279
def test_rate_limiting_same_ip_different_users():
@@ -385,10 +388,10 @@ def test_rate_limits_by_user_with_different_ips():
385388
"block": False
386389
}
387390
# This request should trigger the rate limit by group
388-
assert should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1") == {
389-
"block": True,
390-
"trigger": "group",
391-
}
391+
result = should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1")
392+
assert result["block"] is True
393+
assert result["trigger"] == "group"
394+
assert result["retry_after_seconds"] >= 0
392395

393396

394397
def test_rate_limits_different_users_in_same_group():
@@ -420,12 +423,12 @@ def test_rate_limits_different_users_in_same_group():
420423
route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1"
421424
) == {"block": False}
422425
# This request should trigger the rate limit by group
423-
assert should_ratelimit_request(
426+
result = should_ratelimit_request(
424427
route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group1"
425-
) == {
426-
"block": True,
427-
"trigger": "group",
428-
}
428+
)
429+
assert result["block"] is True
430+
assert result["trigger"] == "group"
431+
assert result["retry_after_seconds"] >= 0
429432

430433

431434
def test_works_with_multiple_rate_limit_groups_and_different_users():
@@ -457,30 +460,30 @@ def test_works_with_multiple_rate_limit_groups_and_different_users():
457460
route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group2"
458461
) == {"block": False}
459462
# This request should trigger the rate limit for group1
460-
assert should_ratelimit_request(
463+
result = should_ratelimit_request(
461464
route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1"
462-
) == {
463-
"block": True,
464-
"trigger": "group",
465-
}
465+
)
466+
assert result["block"] is True
467+
assert result["trigger"] == "group"
468+
assert result["retry_after_seconds"] >= 0
466469
# This request should also trigger the rate limit for group1
467-
assert should_ratelimit_request(
470+
result = should_ratelimit_request(
468471
route_metadata, "1.2.3.4", {"id": "4321"}, cm, "group1"
469-
) == {
470-
"block": True,
471-
"trigger": "group",
472-
}
472+
)
473+
assert result["block"] is True
474+
assert result["trigger"] == "group"
475+
assert result["retry_after_seconds"] >= 0
473476
# First request from user 953, group2
474477
assert should_ratelimit_request(
475478
route_metadata, "4.3.2.1", {"id": "953"}, cm, "group2"
476479
) == {"block": False}
477480
# This request should trigger the rate limit for group2
478-
assert should_ratelimit_request(
481+
result = should_ratelimit_request(
479482
route_metadata, "4.3.2.1", {"id": "1563"}, cm, "group2"
480-
) == {
481-
"block": True,
482-
"trigger": "group",
483-
}
483+
)
484+
assert result["block"] is True
485+
assert result["trigger"] == "group"
486+
assert result["retry_after_seconds"] >= 0
484487

485488

486489
def test_rate_limits_by_group_if_user_is_not_set():
@@ -512,10 +515,10 @@ def test_rate_limits_by_group_if_user_is_not_set():
512515
"block": False
513516
}
514517
# This request should trigger the rate limit by group
515-
assert should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1") == {
516-
"block": True,
517-
"trigger": "group",
518-
}
518+
result = should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1")
519+
assert result["block"] is True
520+
assert result["trigger"] == "group"
521+
assert result["retry_after_seconds"] >= 0
519522

520523

521524
def test_does_not_rate_limit_excluded_users():

aikido_zen/ratelimiting/rate_limiter.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Mostly exports the class RateLimiter
33
"""
44

5+
import math
56
from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms
67
from .lru_cache import LRUCache
78

@@ -18,7 +19,8 @@ def __init__(self, max_items, time_to_live_in_ms):
1819

1920
def is_allowed(self, key, window_size_in_ms, max_requests):
2021
"""
21-
Checks if the request is allowed given the history
22+
Checks if the request is allowed given the history.
23+
Returns {"allowed": True} or {"allowed": False, "retry_after_seconds": int}.
2224
"""
2325
current_time = get_unixtime_ms()
2426
request_timestamps = self.rate_limited_items.get(key) or []
@@ -39,5 +41,13 @@ def is_allowed(self, key, window_size_in_ms, max_requests):
3941
request_timestamps.append(current_time)
4042
self.rate_limited_items.set(key, request_timestamps)
4143

42-
# if the total amount of requests in the current window exceeds max requests, we rate-limit
43-
return len(request_timestamps) <= max_requests
44+
if len(request_timestamps) <= max_requests:
45+
return {"allowed": True}
46+
47+
retry_after_ms = max(
48+
0, request_timestamps[0] + window_size_in_ms - current_time
49+
)
50+
return {
51+
"allowed": False,
52+
"retry_after_seconds": math.ceil(retry_after_ms / 1000),
53+
}

0 commit comments

Comments
 (0)