Skip to content

Commit 2b9e8f2

Browse files
committed
Add FastAPI integration, tracer, and utils
Add FastAPI middleware and decorator for rate limiting (ratelink/integrations/fastapi.py) including a RateLimitExceeded exception and an exception handler. Introduce observability/tracer.py with a Tracer protocol, NoOpTracer, OpenTelemetryTracer (conditional), RateLimiterTracer helpers, and create_tracer factory. Add package-level exports and version in ratelink/utils/__init__.py to re-export core types, algorithms, backends, rate limiter classes, observability tools, and integrations. Add common key generator utilities in ratelink/utils/key_generators.py (by_ip, by_user_id, by_api_key, by_route, by_endpoint, composite_key, by_session, custom_key) to standardize rate-limiting keys.
1 parent ea70c8e commit 2b9e8f2

4 files changed

Lines changed: 531 additions & 0 deletions

File tree

ratelink/integrations/fastapi.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import functools
2+
from typing import Any, Callable, Optional
3+
from starlette.middleware.base import BaseHTTPMiddleware
4+
from starlette.requests import Request
5+
from starlette.responses import JSONResponse, Response
6+
from ratelink.utils.key_generators import KeyGeneratorFunc, by_ip
7+
8+
class FastAPIRateLimitMiddleware(BaseHTTPMiddleware):
9+
def __init__(
10+
self,
11+
app,
12+
limiter: Any,
13+
key_generator: Optional[KeyGeneratorFunc] = None,
14+
skip_paths: Optional[list] = None
15+
):
16+
super().__init__(app)
17+
self.limiter = limiter
18+
self.key_generator = key_generator or by_ip()
19+
self.skip_paths = set(skip_paths or [])
20+
21+
async def dispatch(self, request: Request, call_next):
22+
if request.url.path in self.skip_paths:
23+
return await call_next(request)
24+
25+
key = self.key_generator(request)
26+
27+
allowed, state = self.limiter.check(key)
28+
29+
if not allowed:
30+
retry_after = state.get('retry_after', 0)
31+
return JSONResponse(
32+
status_code=429,
33+
content={
34+
"error": "Rate limit exceeded",
35+
"limit": state.get('limit', 0),
36+
"remaining": 0,
37+
"retry_after": retry_after
38+
},
39+
headers={
40+
"Retry-After": str(int(retry_after)),
41+
"X-RateLimit-Limit": str(state.get('limit', 0)),
42+
"X-RateLimit-Remaining": "0",
43+
"X-RateLimit-Reset": str(int(retry_after))
44+
}
45+
)
46+
47+
response = await call_next(request)
48+
49+
response.headers["X-RateLimit-Limit"] = str(state.get('limit', 0))
50+
response.headers["X-RateLimit-Remaining"] = str(state.get('remaining', 0))
51+
if state.get('reset_after'):
52+
response.headers["X-RateLimit-Reset"] = str(int(state.get('reset_after', 0)))
53+
54+
return response
55+
56+
57+
def rate_limit(
58+
limiter: Any,
59+
limit: Optional[int] = None,
60+
window: Optional[int] = None,
61+
key_generator: Optional[KeyGeneratorFunc] = None
62+
):
63+
if key_generator is None:
64+
key_generator = by_ip()
65+
66+
def decorator(func: Callable) -> Callable:
67+
@functools.wraps(func)
68+
async def wrapper(*args, **kwargs):
69+
request = None
70+
for arg in args:
71+
if isinstance(arg, Request):
72+
request = arg
73+
break
74+
75+
if request is None and 'request' in kwargs:
76+
request = kwargs['request']
77+
78+
if request is None:
79+
return await func(*args, **kwargs)
80+
81+
key = key_generator(request)
82+
83+
allowed, state = limiter.check(key)
84+
85+
if not allowed:
86+
retry_after = state.get('retry_after', 0)
87+
return JSONResponse(
88+
status_code=429,
89+
content={
90+
"error": "Rate limit exceeded",
91+
"limit": state.get('limit', 0),
92+
"remaining": 0,
93+
"retry_after": retry_after
94+
},
95+
headers={
96+
"Retry-After": str(int(retry_after)),
97+
"X-RateLimit-Limit": str(state.get('limit', 0)),
98+
"X-RateLimit-Remaining": "0",
99+
}
100+
)
101+
102+
return await func(*args, **kwargs)
103+
104+
return wrapper
105+
106+
return decorator
107+
108+
class RateLimitExceeded(Exception):
109+
def __init__(self, state: dict):
110+
self.state = state
111+
self.retry_after = state.get('retry_after', 0)
112+
self.limit = state.get('limit', 0)
113+
self.remaining = state.get('remaining', 0)
114+
super().__init__(f"Rate limit exceeded. Retry after {self.retry_after}s")
115+
116+
def setup_exception_handler(app):
117+
@app.exception_handler(RateLimitExceeded)
118+
async def rate_limit_exception_handler(request: Request, exc: RateLimitExceeded):
119+
return JSONResponse(
120+
status_code=429,
121+
content={
122+
"error": "Rate limit exceeded",
123+
"limit": exc.limit,
124+
"remaining": exc.remaining,
125+
"retry_after": exc.retry_after
126+
},
127+
headers={
128+
"Retry-After": str(int(exc.retry_after)),
129+
"X-RateLimit-Limit": str(exc.limit),
130+
"X-RateLimit-Remaining": str(exc.remaining),
131+
}
132+
)

ratelink/observability/tracer.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
from contextlib import contextmanager
2+
from typing import Any, Iterator, Optional, Protocol
3+
4+
class Tracer(Protocol):
5+
@contextmanager
6+
def span(
7+
self,
8+
name: str,
9+
**attributes: Any
10+
) -> Iterator[None]:
11+
"""
12+
Create a tracing span.
13+
14+
Args:
15+
name: Span name
16+
**attributes: Span attributes
17+
18+
Yields:
19+
None
20+
"""
21+
...
22+
23+
24+
class NoOpTracer:
25+
@contextmanager
26+
def span(
27+
self,
28+
name: str,
29+
**attributes: Any
30+
) -> Iterator[None]:
31+
"""Create a no-op span."""
32+
yield
33+
34+
35+
try:
36+
from ratelink.m import trace
37+
from openteletry.trace import Status, StatusCode
38+
OTEL_AVAILABLE = True
39+
except ImportError:
40+
OTEL_AVAILABLE = False
41+
42+
43+
class OpenTelemetryTracer:
44+
def __init__(
45+
self,
46+
service_name: str = "rate-limiter",
47+
tracer_provider: Optional[Any] = None
48+
):
49+
if not OTEL_AVAILABLE:
50+
raise ImportError(
51+
"opentelemetry-api is required for OpenTelemetryTracer. "
52+
"Install with: pip install opentelemetry-api opentelemetry-sdk"
53+
)
54+
55+
if tracer_provider:
56+
self._tracer = tracer_provider.get_tracer(service_name)
57+
else:
58+
self._tracer = trace.get_tracer(service_name)
59+
60+
@contextmanager
61+
def span(
62+
self,
63+
name: str,
64+
**attributes: Any
65+
) -> Iterator[None]:
66+
with self._tracer.start_as_current_span(name) as span:
67+
for key, value in attributes.items():
68+
if value is not None:
69+
span.set_attribute(key, value)
70+
try:
71+
yield
72+
except Exception as e:
73+
span.set_status(Status(StatusCode.ERROR))
74+
span.record_exception(e)
75+
raise
76+
77+
78+
class RateLimiterTracer:
79+
def __init__(self, tracer: Optional[Tracer] = None):
80+
self._tracer = tracer or NoOpTracer()
81+
82+
@contextmanager
83+
def trace_check(
84+
self,
85+
key: str,
86+
algorithm: str,
87+
backend: str,
88+
weight: int = 1
89+
) -> Iterator[None]:
90+
with self._tracer.span(
91+
"rate_limit.check",
92+
key=key,
93+
algorithm=algorithm,
94+
backend=backend,
95+
weight=weight
96+
):
97+
yield
98+
99+
@contextmanager
100+
def trace_backend_operation(
101+
self,
102+
backend: str,
103+
operation: str,
104+
key: Optional[str] = None
105+
) -> Iterator[None]:
106+
attrs = {
107+
"backend": backend,
108+
"operation": operation,
109+
}
110+
if key:
111+
attrs["key"] = key
112+
113+
with self._tracer.span(
114+
f"rate_limit.backend.{operation}",
115+
**attrs
116+
):
117+
yield
118+
119+
@contextmanager
120+
def trace_algorithm(
121+
self,
122+
algorithm: str,
123+
key: str
124+
) -> Iterator[None]:
125+
with self._tracer.span(
126+
f"rate_limit.algorithm.{algorithm}",
127+
algorithm=algorithm,
128+
key=key
129+
):
130+
yield
131+
132+
133+
def create_tracer(
134+
enabled: bool = True,
135+
service_name: str = "rate-limiter",
136+
tracer_provider: Optional[Any] = None
137+
) -> Tracer:
138+
if not enabled:
139+
return NoOpTracer()
140+
141+
if not OTEL_AVAILABLE:
142+
import warnings
143+
warnings.warn(
144+
"OpenTelemetry not available. Install with: "
145+
"pip install opentelemetry-api opentelemetry-sdk"
146+
)
147+
return NoOpTracer()
148+
149+
return OpenTelemetryTracer(
150+
service_name=service_name,
151+
tracer_provider=tracer_provider
152+
)

ratelink/utils/__init__.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from ..core.types import (
2+
AlgorithmType,
3+
BackendType,
4+
WindowType,
5+
RateLimitState,
6+
RateLimitResult,
7+
RateLimitException,
8+
LimitExceeded,
9+
BackendError,
10+
ConfigError,
11+
TimeoutError,
12+
)
13+
from ..core.abstractions import Algorithm, Backend, RateLimiter as RateLimiterABC
14+
from ..algorithms.token_bucket import TokenBucketAlgorithm
15+
from ..algorithms.sliding_window import SlidingWindowAlgorithm
16+
from ..algorithms.leaky_bucket import LeakyBucketAlgorithm
17+
from ..algorithms.fixed_window import FixedWindowAlgorithm
18+
from ..algorithms.gcra import GCRAAlgorithm
19+
from ..algorithms.hierarchical import HierarchicalTokenBucket, FairQueueingAlgorithm
20+
from ..backends.memory import MemoryBackend
21+
from ..backends.multi_region import MultiRegionBackend
22+
from ..backends.custom import CustomBackendInterface
23+
try:
24+
from ..backends.redis import RedisBackend
25+
except ImportError:
26+
pass
27+
try:
28+
from ..backends.postgresql import PostgreSQLBackend
29+
except ImportError:
30+
pass
31+
try:
32+
from ..backends.dynamodb import DynamoDBBackend
33+
except ImportError:
34+
pass
35+
try:
36+
from ..backends.mongodb import MongoDBBackend
37+
except ImportError:
38+
pass
39+
from ..rate_limiter import RateLimiter
40+
from ..config import ConfigLoader, RuleEngine
41+
from ..priority_limiter import PriorityRateLimiter
42+
from ..quota_pool import QuotaPool, SharedQuotaManager
43+
from ..adaptive_limiter import AdaptiveRateLimiter
44+
from ..observability.metrics import (
45+
MetricsCollector,
46+
MetricValue,
47+
HistogramBucket,
48+
PrometheusExporter,
49+
create_prometheus_exporter,
50+
)
51+
from ..observability.logging import AuditLogger
52+
from ..integrations.statsd import StatsDExporter
53+
54+
__version__ = "0.4.0"
55+
56+
__all__ = [
57+
"RateLimiter",
58+
"ConfigLoader",
59+
"RuleEngine",
60+
"PriorityRateLimiter",
61+
"QuotaPool",
62+
"SharedQuotaManager",
63+
"AdaptiveRateLimiter",
64+
"AlgorithmType",
65+
"BackendType",
66+
"WindowType",
67+
"RateLimitState",
68+
"RateLimitResult",
69+
"RateLimitException",
70+
"LimitExceeded",
71+
"BackendError",
72+
"ConfigError",
73+
"TimeoutError",
74+
"Algorithm",
75+
"Backend",
76+
"RateLimiterABC",
77+
"TokenBucketAlgorithm",
78+
"SlidingWindowAlgorithm",
79+
"LeakyBucketAlgorithm",
80+
"FixedWindowAlgorithm",
81+
"GCRAAlgorithm",
82+
"HierarchicalTokenBucket",
83+
"FairQueueingAlgorithm",
84+
"MemoryBackend",
85+
"MultiRegionBackend",
86+
"CustomBackendInterface",
87+
"RedisBackend",
88+
"PostgreSQLBackend",
89+
"DynamoDBBackend",
90+
"MongoDBBackend",
91+
"MetricsCollector",
92+
"MetricValue",
93+
"HistogramBucket",
94+
"PrometheusExporter",
95+
"create_prometheus_exporter",
96+
"AuditLogger",
97+
"StatsDExporter",
98+
]

0 commit comments

Comments
 (0)