Skip to content

Commit 091e94c

Browse files
committed
Add rate_limit decorator and RateLimitExceeded
Introduce a new rate_limit decorator (ratelink/utils/decorators.py) to apply rate limiting to both sync and async callables. Adds RateLimitExceeded exception carrying retry_after, limit, and remaining metadata, an internal _extract_request helper to locate request objects in args/kwargs, and defaults the key generator to by_ip(). The decorator uses limiter.check(key) to determine allowance (optional limit/window params are accepted but currently unused). Also includes an updated .DS_Store entry.
1 parent 132399b commit 091e94c

2 files changed

Lines changed: 91 additions & 0 deletions

File tree

ratelink/.DS_Store

0 Bytes
Binary file not shown.

ratelink/utils/decorators.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import functools
2+
from typing import Any, Callable, Optional
3+
from ratelink.utils.key_generators import KeyGeneratorFunc, by_ip
4+
5+
class RateLimitExceeded(Exception):
6+
def __init__(
7+
self,
8+
message: str = "Rate limit exceeded",
9+
retry_after: float = 0.0,
10+
limit: int = 0,
11+
remaining: int = 0
12+
):
13+
super().__init__(message)
14+
self.message = message
15+
self.retry_after = retry_after
16+
self.limit = limit
17+
self.remaining = remaining
18+
19+
20+
def rate_limit(
21+
limiter: Any,
22+
limit: Optional[int] = None,
23+
window: Optional[int] = None,
24+
key_generator: Optional[KeyGeneratorFunc] = None
25+
):
26+
if key_generator is None:
27+
key_generator = by_ip()
28+
29+
def decorator(func: Callable) -> Callable:
30+
@functools.wraps(func)
31+
async def async_wrapper(*args, **kwargs):
32+
request = _extract_request(args, kwargs)
33+
34+
if request is None:
35+
return await func(*args, **kwargs)
36+
37+
key = key_generator(request)
38+
39+
if limit is not None and window is not None:
40+
pass
41+
42+
allowed, state = limiter.check(key)
43+
44+
if not allowed:
45+
raise RateLimitExceeded(
46+
retry_after=state.get('retry_after', 0),
47+
limit=state.get('limit', 0),
48+
remaining=state.get('remaining', 0)
49+
)
50+
51+
return await func(*args, **kwargs)
52+
53+
@functools.wraps(func)
54+
def sync_wrapper(*args, **kwargs):
55+
request = _extract_request(args, kwargs)
56+
57+
if request is None:
58+
return func(*args, **kwargs)
59+
60+
key = key_generator(request)
61+
62+
allowed, state = limiter.check(key)
63+
64+
if not allowed:
65+
raise RateLimitExceeded(
66+
retry_after=state.get('retry_after', 0),
67+
limit=state.get('limit', 0),
68+
remaining=state.get('remaining', 0)
69+
)
70+
71+
return func(*args, **kwargs)
72+
73+
if functools.iscoroutinefunction(func):
74+
return async_wrapper
75+
else:
76+
return sync_wrapper
77+
78+
return decorator
79+
80+
81+
def _extract_request(args: tuple, kwargs: dict) -> Optional[Any]:
82+
for key in ['request', 'req']:
83+
if key in kwargs:
84+
return kwargs[key]
85+
86+
if args:
87+
first_arg = args[0]
88+
if hasattr(first_arg, 'headers') or hasattr(first_arg, 'META'):
89+
return first_arg
90+
91+
return None

0 commit comments

Comments
 (0)