1+ import pytest
2+
3+ try :
4+ from fastapi import FastAPI , Request
5+ from fastapi .testclient import TestClient
6+ FASTAPI_AVAILABLE = True
7+ except ImportError :
8+ FASTAPI_AVAILABLE = False
9+
10+ if FASTAPI_AVAILABLE :
11+ from ratelink .integration .fastapi import (
12+ FastAPIRateLimitMiddleware ,
13+ rate_limit ,
14+ )
15+ from ratelink .utils .key_generators import by_ip , by_route , composite_key
16+
17+ from mock_limiter import MockRateLimiter
18+
19+ @pytest .mark .skipif (not FASTAPI_AVAILABLE , reason = "FastAPI not installed" )
20+ class TestFastAPIMiddleware :
21+ def test_middleware_allows_request (self ):
22+ app = FastAPI ()
23+ limiter = MockRateLimiter (should_allow = True )
24+
25+ app .add_middleware (
26+ FastAPIRateLimitMiddleware ,
27+ limiter = limiter ,
28+ key_generator = by_ip ()
29+ )
30+
31+ @app .get ("/test" )
32+ async def test_endpoint ():
33+ return {"status" : "ok" }
34+
35+ client = TestClient (app )
36+ response = client .get ("/test" )
37+
38+ assert response .status_code == 200
39+ assert response .json () == {"status" : "ok" }
40+ assert "X-RateLimit-Limit" in response .headers
41+ assert "X-RateLimit-Remaining" in response .headers
42+
43+ def test_middleware_blocks_request (self ):
44+ app = FastAPI ()
45+ limiter = MockRateLimiter (should_allow = False )
46+
47+ app .add_middleware (
48+ FastAPIRateLimitMiddleware ,
49+ limiter = limiter ,
50+ key_generator = by_ip ()
51+ )
52+
53+ @app .get ("/test" )
54+ async def test_endpoint ():
55+ return {"status" : "ok" }
56+
57+ client = TestClient (app )
58+ response = client .get ("/test" )
59+
60+ assert response .status_code == 429
61+ assert "error" in response .json ()
62+ assert response .json ()["error" ] == "Rate limit exceeded"
63+ assert "Retry-After" in response .headers
64+
65+ def test_middleware_skip_paths (self ):
66+ app = FastAPI ()
67+ limiter = MockRateLimiter (should_allow = False )
68+
69+ app .add_middleware (
70+ FastAPIRateLimitMiddleware ,
71+ limiter = limiter ,
72+ skip_paths = ["/health" ]
73+ )
74+
75+ @app .get ("/health" )
76+ async def health ():
77+ return {"status" : "healthy" }
78+
79+ @app .get ("/api" )
80+ async def api ():
81+ return {"data" : []}
82+
83+ client = TestClient (app )
84+
85+ response = client .get ("/health" )
86+ assert response .status_code == 200
87+
88+ response = client .get ("/api" )
89+ assert response .status_code == 429
90+
91+ def test_middleware_custom_key_generator (self ):
92+ app = FastAPI ()
93+ limiter = MockRateLimiter (should_allow = True )
94+
95+ app .add_middleware (
96+ FastAPIRateLimitMiddleware ,
97+ limiter = limiter ,
98+ key_generator = by_route ()
99+ )
100+
101+ @app .get ("/endpoint1" )
102+ async def endpoint1 ():
103+ return {"n" : 1 }
104+
105+ @app .get ("/endpoint2" )
106+ async def endpoint2 ():
107+ return {"n" : 2 }
108+
109+ client = TestClient (app )
110+ client .get ("/endpoint1" )
111+ client .get ("/endpoint2" )
112+
113+ assert len (limiter .check_calls ) == 2
114+ key1 , key2 = limiter .check_calls [0 ][0 ], limiter .check_calls [1 ][0 ]
115+ assert key1 != key2
116+ assert "route" in key1
117+ assert "route" in key2
118+
119+ @pytest .mark .skipif (not FASTAPI_AVAILABLE , reason = "FastAPI not installed" )
120+ class TestFastAPIDecorator :
121+ def test_decorator_allows_request (self ):
122+ app = FastAPI ()
123+ limiter = MockRateLimiter (should_allow = True )
124+
125+ @app .get ("/test" )
126+ @rate_limit (limiter , key_generator = by_ip ())
127+ async def test_endpoint (request : Request ):
128+ return {"status" : "ok" }
129+
130+ client = TestClient (app )
131+ response = client .get ("/test" )
132+
133+ assert response .status_code == 200
134+ assert response .json () == {"status" : "ok" }
135+
136+ def test_decorator_blocks_request (self ):
137+ app = FastAPI ()
138+ limiter = MockRateLimiter (should_allow = False )
139+
140+ @app .get ("/test" )
141+ @rate_limit (limiter , key_generator = by_ip ())
142+ async def test_endpoint (request : Request ):
143+ return {"status" : "ok" }
144+
145+ client = TestClient (app )
146+ response = client .get ("/test" )
147+
148+ assert response .status_code == 429
149+ assert "error" in response .json ()
150+
151+ def test_decorator_per_endpoint_limits (self ):
152+ app = FastAPI ()
153+ limiter_strict = MockRateLimiter (should_allow = False )
154+ limiter_lenient = MockRateLimiter (should_allow = True )
155+
156+ @app .get ("/strict" )
157+ @rate_limit (limiter_strict , key_generator = by_ip ())
158+ async def strict_endpoint (request : Request ):
159+ return {"endpoint" : "strict" }
160+
161+ @app .get ("/lenient" )
162+ @rate_limit (limiter_lenient , key_generator = by_ip ())
163+ async def lenient_endpoint (request : Request ):
164+ return {"endpoint" : "lenient" }
165+
166+ client = TestClient (app )
167+
168+ response = client .get ("/strict" )
169+ assert response .status_code == 429
170+
171+ response = client .get ("/lenient" )
172+ assert response .status_code == 200
173+
174+ def test_decorator_composite_key (self ):
175+ app = FastAPI ()
176+ limiter = MockRateLimiter (should_allow = True )
177+
178+ @app .get ("/test" )
179+ @rate_limit (limiter , key_generator = composite_key (by_ip (), by_route ()))
180+ async def test_endpoint (request : Request ):
181+ return {"status" : "ok" }
182+
183+ client = TestClient (app )
184+ response = client .get ("/test" )
185+
186+ assert response .status_code == 200
187+ assert len (limiter .check_calls ) == 1
188+ key = limiter .check_calls [0 ][0 ]
189+ assert "ip:" in key
190+ assert "route:" in key
191+
192+ def test_decorator_without_request (self ):
193+ app = FastAPI ()
194+ limiter = MockRateLimiter (should_allow = True )
195+
196+ @app .get ("/test" )
197+ @rate_limit (limiter , key_generator = by_ip ())
198+ async def test_endpoint ():
199+ return {"status" : "ok" }
200+
201+ client = TestClient (app )
202+ response = client .get ("/test" )
203+
204+ assert response .status_code in [200 , 429 ]
205+
206+ def test_response_headers_on_success (self ):
207+ app = FastAPI ()
208+ limiter = MockRateLimiter (should_allow = True , state = {
209+ 'allowed' : True ,
210+ 'remaining' : 75 ,
211+ 'limit' : 100 ,
212+ 'retry_after' : 0 ,
213+ 'reset_after' : 45.0
214+ })
215+
216+ app .add_middleware (
217+ FastAPIRateLimitMiddleware ,
218+ limiter = limiter
219+ )
220+
221+ @app .get ("/test" )
222+ async def test_endpoint ():
223+ return {"status" : "ok" }
224+
225+ client = TestClient (app )
226+ response = client .get ("/test" )
227+
228+ assert response .status_code == 200
229+ assert response .headers ["X-RateLimit-Limit" ] == "100"
230+ assert response .headers ["X-RateLimit-Remaining" ] == "75"
231+
232+ def test_retry_after_header_format (self ):
233+ app = FastAPI ()
234+ limiter = MockRateLimiter (should_allow = False , state = {
235+ 'allowed' : False ,
236+ 'remaining' : 0 ,
237+ 'limit' : 100 ,
238+ 'retry_after' : 45.7 ,
239+ })
240+
241+ app .add_middleware (
242+ FastAPIRateLimitMiddleware ,
243+ limiter = limiter
244+ )
245+
246+ @app .get ("/test" )
247+ async def test_endpoint ():
248+ return {"status" : "ok" }
249+
250+ client = TestClient (app )
251+ response = client .get ("/test" )
252+
253+ assert response .status_code == 429
254+ assert response .headers ["Retry-After" ] == "45"
255+ assert response .json ()["retry_after" ] == 45.7
256+
257+
258+ @pytest .mark .skipif (not FASTAPI_AVAILABLE , reason = "FastAPI not installed" )
259+ class TestFastAPIAsync :
260+ def test_async_endpoint_with_decorator (self ):
261+ app = FastAPI ()
262+ limiter = MockRateLimiter (should_allow = True )
263+
264+ @app .get ("/async" )
265+ @rate_limit (limiter )
266+ async def async_endpoint (request : Request ):
267+ return {"async" : True }
268+
269+ client = TestClient (app )
270+ response = client .get ("/async" )
271+
272+ assert response .status_code == 200
273+ assert response .json () == {"async" : True }
274+
275+ def test_middleware_doesnt_block_async (self ):
276+ app = FastAPI ()
277+ limiter = MockRateLimiter (should_allow = True )
278+
279+ app .add_middleware (
280+ FastAPIRateLimitMiddleware ,
281+ limiter = limiter
282+ )
283+
284+ call_order = []
285+
286+ @app .get ("/test" )
287+ async def test_endpoint ():
288+ call_order .append ("endpoint" )
289+ return {"status" : "ok" }
290+
291+ client = TestClient (app )
292+ response = client .get ("/test" )
293+
294+ assert response .status_code == 200
295+ assert "endpoint" in call_order
0 commit comments