Skip to content

Commit bf37545

Browse files
author
rodrigo.nogueira
committed
Add rate limiting middleware example using TCPSite.port
- Add examples/rate_limit_middleware.py with token bucket rate limiter - Add examples/tests/test_rate_limit_middleware.py with tests - Add examples/tests/pytest.ini for test configuration - Add CHANGES/11969.doc.rst changelog fragment
1 parent 601bedb commit bf37545

4 files changed

Lines changed: 310 additions & 0 deletions

File tree

CHANGES/11969.doc.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Added rate-limiting client middleware example (``examples/rate_limit_middleware.py``)
2+
demonstrating token-bucket rate limiting with per-domain support and ``Retry-After``
3+
header handling -- by :user:`rodrigobnogueira`.

examples/rate_limit_middleware.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Client-side rate-limiting middleware example for aiohttp.
4+
5+
Demonstrates how to throttle outgoing requests using a token-bucket
6+
algorithm. This is *not* server-side rate limiting — it limits how
7+
fast the client sends requests so it does not overwhelm upstream
8+
servers or exceed API quotas.
9+
10+
Features:
11+
- Configurable rate and burst size
12+
- Optional per-domain buckets
13+
- Automatic ``Retry-After`` header handling
14+
"""
15+
16+
import asyncio
17+
import logging
18+
import time
19+
from collections import defaultdict, deque
20+
from http import HTTPStatus
21+
22+
from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web
23+
24+
logging.basicConfig(level=logging.INFO)
25+
_LOGGER = logging.getLogger(__name__)
26+
27+
28+
class TokenBucket:
    """FIFO token-bucket using an ``asyncio.Event`` queue.

    Each caller appends its own event to a FIFO queue and waits.
    A single ``_schedule`` coroutine services the queue front-to-back,
    sleeping until each slot's send time arrives and then unblocking
    the corresponding caller. This guarantees strict FIFO ordering
    even under high concurrency.
    """

    def __init__(self, rate: float, burst: int) -> None:
        """Initialize the bucket.

        :param rate: steady-state throughput in acquires per second.
        :param burst: number of acquires allowed instantly after idle.
        """
        self._interval = 1.0 / rate
        self._burst = burst
        # Start *burst* intervals in the past so the first
        # ``burst`` acquires are instant.
        self._next_send = time.monotonic() - burst * self._interval
        self._waiters: deque[asyncio.Event] = deque()
        # Strong reference to the scheduler task.  The event loop keeps
        # only weak references to tasks, so a discarded ``ensure_future``
        # result may be garbage-collected before it finishes, silently
        # deadlocking every waiter.  Holding the task here prevents that.
        self._schedule_task: asyncio.Task[None] | None = None

    async def acquire(self) -> None:
        """Reserve the next send slot and wait until it arrives."""
        event = asyncio.Event()
        self._waiters.append(event)
        self._ensure_scheduling()
        await event.wait()

    def _ensure_scheduling(self) -> None:
        """Start the scheduler loop if it is not already running."""
        # ``done()`` covers both normal completion and cancellation/error,
        # so a fresh scheduler is started whenever the old one is gone.
        if self._schedule_task is None or self._schedule_task.done():
            self._schedule_task = asyncio.ensure_future(self._schedule())

    async def _schedule(self) -> None:
        """Service waiters in FIFO order, one slot at a time."""
        while self._waiters:
            now = time.monotonic()
            # Cap drift so idle periods never accumulate
            # more than *burst* free slots.
            self._next_send = max(self._next_send, now - self._burst * self._interval)
            self._next_send += self._interval
            delay = self._next_send - now
            if delay > 0:
                await asyncio.sleep(delay)
            self._waiters.popleft().set()
73+
74+
75+
class RateLimitMiddleware:
    """Middleware that rate limits requests using token bucket algorithm."""

    rate: float
    burst: int
    per_domain: bool
    respect_retry_after: bool

    def __init__(
        self,
        rate: float = 10.0,
        burst: int = 10,
        per_domain: bool = False,
        respect_retry_after: bool = True,
    ) -> None:
        """Configure the limiter.

        :param rate: steady-state requests per second.
        :param burst: requests allowed instantly after an idle period.
        :param per_domain: give every request host its own bucket.
        :param respect_retry_after: sleep when a 429 carries ``Retry-After``.
        """
        self.rate = rate
        self.burst = burst
        self.per_domain = per_domain
        self.respect_retry_after = respect_retry_after
        self._global_bucket = TokenBucket(rate, burst)
        # Lazily creates one bucket per host on first access.
        self._domain_buckets: dict[str, TokenBucket] = defaultdict(
            lambda: TokenBucket(rate, burst)
        )

    def _get_bucket(self, request: ClientRequest) -> TokenBucket:
        """Return the bucket governing *request* (per-host or shared)."""
        if not self.per_domain:
            return self._global_bucket
        return self._domain_buckets[request.url.host or "unknown"]

    async def _handle_retry_after(self, response: ClientResponse) -> None:
        """Sleep for the server-requested delay after a 429 response."""
        if response.status != HTTPStatus.TOO_MANY_REQUESTS:
            return
        retry_after = response.headers.get("Retry-After")
        if not retry_after:
            return
        try:
            wait_seconds = float(retry_after)
        except ValueError:
            # The HTTP-date form of Retry-After is not handled here.
            _LOGGER.debug(
                "Retry-After is not a number (likely HTTP-date): %s", retry_after
            )
            return
        _LOGGER.info("Server requested Retry-After: %ss", wait_seconds)
        await asyncio.sleep(wait_seconds)

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute request with rate limiting."""
        await self._get_bucket(request).acquire()

        response = await handler(request)

        if self.respect_retry_after:
            await self._handle_retry_after(response)

        return response
134+
135+
136+
# ------------------------------------------------------------------
137+
# Self-contained demo (no external dependencies)
138+
async def _demo_handler(_request: web.Request) -> web.Response:
    """Demo endpoint: always answer with a plain ``OK`` body."""
    ok = web.Response(text="OK")
    return ok
140+
141+
142+
async def main() -> None:
    """Spin up a throwaway local server and send five rate-limited requests."""
    demo_app = web.Application()
    demo_app.router.add_get("/get", _demo_handler)
    app_runner = web.AppRunner(demo_app)
    await app_runner.setup()
    # Bind to port 0 so the OS assigns a free ephemeral port.
    tcp_site = web.TCPSite(app_runner, "127.0.0.1", 0)
    await tcp_site.start()

    bound_port: int = tcp_site.port
    limiter = RateLimitMiddleware(rate=5.0, burst=2)
    t0 = time.monotonic()

    try:
        async with ClientSession(
            base_url=f"http://127.0.0.1:{bound_port}",
            middlewares=(limiter,),
        ) as session:
            for request_no in range(5):
                async with session.get("/get") as response:
                    elapsed = time.monotonic() - t0
                    print(
                        f"Request {request_no + 1}: {response.status} "
                        f"at t={elapsed:.2f}s"
                    )
    finally:
        await app_runner.cleanup()
165+
166+
167+
# Script entry point: run the self-contained demo server and client.
if __name__ == "__main__":
    asyncio.run(main())

examples/tests/pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[pytest]
2+
pythonpath = ..
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Tests for the rate_limit_middleware.py example.
3+
4+
Run with:
5+
pytest examples/tests/test_rate_limit_middleware.py -v
6+
"""
7+
8+
import asyncio
9+
import time
10+
11+
import pytest
12+
from rate_limit_middleware import RateLimitMiddleware, TokenBucket
13+
14+
from aiohttp import web
15+
from aiohttp.pytest_plugin import AiohttpClient
16+
17+
18+
async def _ok_handler(request: web.Request) -> web.Response:
    """Always answer with a 200 ``OK`` body."""
    response = web.Response(text="OK")
    return response
20+
21+
22+
def _make_app() -> web.Application:
    """Build a minimal application exposing ``GET /api``."""
    application = web.Application()
    application.router.add_get("/api", _ok_handler)
    return application
26+
27+
28+
@pytest.mark.asyncio
async def test_token_bucket_allows_burst() -> None:
    """Tokens up to burst size should be available immediately."""
    limiter = TokenBucket(rate=10.0, burst=3)
    t0 = time.monotonic()
    for _ in range(3):
        await limiter.acquire()
    # All three acquires fall inside the burst allowance, so no sleeping.
    assert time.monotonic() - t0 < 0.05
38+
39+
40+
@pytest.mark.asyncio
async def test_token_bucket_refills_after_idle() -> None:
    """After draining, idle time should replenish burst slots."""
    limiter = TokenBucket(rate=100.0, burst=1)
    await limiter.acquire()
    # Idle for five intervals — plenty for the single burst slot to refill.
    await asyncio.sleep(0.05)
    t0 = time.monotonic()
    await limiter.acquire()
    assert time.monotonic() - t0 < 0.02
51+
52+
53+
@pytest.mark.asyncio
async def test_token_bucket_fifo_ordering() -> None:
    """Concurrent acquires should be served in FIFO order."""
    limiter = TokenBucket(rate=100.0, burst=1)
    completed: list[int] = []

    async def numbered_acquire(n: int) -> None:
        await limiter.acquire()
        completed.append(n)

    # Tasks are created (and thus enqueued) in index order.
    pending = [asyncio.create_task(numbered_acquire(i)) for i in range(3)]
    await asyncio.gather(*pending)
    assert completed == [0, 1, 2]
66+
67+
68+
@pytest.mark.asyncio
async def test_rate_limit_middleware_throttles(
    aiohttp_client: AiohttpClient,
) -> None:
    """Global middleware should throttle requests beyond burst."""
    middleware = RateLimitMiddleware(rate=50.0, burst=2)
    client = await aiohttp_client(_make_app(), middlewares=(middleware,))

    t0 = time.monotonic()
    for _ in range(4):
        response = await client.get("/api")
        assert response.status == 200
    elapsed = time.monotonic() - t0

    # 2 burst + 2 throttled at 50/s ≈ 0.04s minimum wait.
    # Upper bound (0.5s) catches hangs or accidental double-sleeps
    # while staying generous enough for CI environments.
    assert 0.02 <= elapsed < 0.5
86+
87+
88+
@pytest.mark.asyncio
async def test_rate_limit_middleware_per_domain(
    aiohttp_client: AiohttpClient,
) -> None:
    """Per-domain middleware should isolate buckets per host."""
    middleware = RateLimitMiddleware(rate=100.0, burst=1, per_domain=True)
    client = await aiohttp_client(_make_app(), middlewares=(middleware,))

    t0 = time.monotonic()
    # Same host, so they share a bucket — second request should wait
    first = await client.get("/api")
    second = await client.get("/api")
    elapsed = time.monotonic() - t0

    assert first.status == 200
    assert second.status == 200
    # Upper bound catches unexpected delays without being flaky on CI
    assert 0.005 <= elapsed < 0.5
106+
107+
108+
@pytest.mark.asyncio
async def test_rate_limit_middleware_respects_retry_after(
    aiohttp_client: AiohttpClient,
) -> None:
    """Middleware should sleep when server returns 429 + Retry-After."""
    calls = 0

    async def rate_limited_handler(request: web.Request) -> web.Response:
        nonlocal calls
        calls += 1
        if calls == 1:
            # First hit: tell the client to back off for 0.1 seconds.
            return web.Response(
                status=429,
                headers={"Retry-After": "0.1"},
            )
        return web.Response(text="OK")

    app = web.Application()
    app.router.add_get("/api", rate_limited_handler)

    middleware = RateLimitMiddleware(rate=100.0, burst=10, respect_retry_after=True)
    client = await aiohttp_client(app, middlewares=(middleware,))

    t0 = time.monotonic()
    response = await client.get("/api")
    elapsed = time.monotonic() - t0

    assert response.status == 429
    # Upper bound catches unexpected delays without being flaky on CI
    assert 0.08 <= elapsed < 0.5

0 commit comments

Comments
 (0)