1+ from asyncio import Lock
2+ from collections import OrderedDict
13from collections .abc import Awaitable , Callable
4+ from dataclasses import dataclass
5+ from hashlib import sha256
6+ from inspect import isawaitable
27from math import ceil
3- from typing import TypeAlias
8+ from typing import TypeAlias , TypeVar
49
510from fastapi import Request , Response
611from fastapi_pagination .utils import is_async_callable
7- from pyrate_limiter import AbstractBucket , Limiter , Rate
12+ from pyrate_limiter import AbstractBucket , BucketFactory , Limiter , Rate , RateItem
813from pyrate_limiter .buckets import RedisBucket
14+ from redis .asyncio import Redis
915from starlette .concurrency import run_in_threadpool
1016
1117from backend .common .exception import errors
1824CallbackCallable : TypeAlias = (
1925 Callable [[Request , Response , int ], None ] | Callable [[Request , Response , int ], Awaitable [None ]]
2026)
27+ T = TypeVar ('T' )
28+
29+ REQUEST_LIMITER_BUCKET_CACHE_MAX_SIZE = 4096
30+ REQUEST_LIMITER_BUCKET_CACHE_BUFFER_MS = 10_000
31+
32+
33+ @dataclass (slots = True )
34+ class RedisBucketState :
35+ """Redis bucket 缓存状态"""
36+
37+ bucket : RedisBucket
38+ last_seen : int
39+
40+
41+ async def _maybe_await (value : T | Awaitable [T ]) -> T :
42+ """
43+ 兼容同步值和异步值
44+
45+ :param value: 同步值或 Awaitable 对象
46+ :return:
47+ """
48+ if isawaitable (value ):
49+ return await value
50+ return value
51+
52+
53+ async def _redis_time_ms (redis : Redis ) -> int :
54+ """
55+ 获取 Redis 服务端当前时间
56+
57+ :return:
58+ """
59+ seconds , microseconds = await redis .time ()
60+ return seconds * 1000 + microseconds // 1000
61+
62+
63+ class RedisTimeBucket (RedisBucket ):
64+ """使用 Redis 服务端时间的 Redis bucket"""
65+
66+ async def now (self ) -> int :
67+ """获取 Redis 服务端当前时间"""
68+ return await _redis_time_ms (self .redis )
69+
70+
71+ class RedisBucketFactory (BucketFactory ):
72+ """按请求标识符路由到独立 Redis bucket"""
73+
74+ def __init__ (
75+ self ,
76+ rates : list [Rate ],
77+ bucket_key : str ,
78+ max_cache_size : int = REQUEST_LIMITER_BUCKET_CACHE_MAX_SIZE ,
79+ ) -> None :
80+ """
81+ 初始化 Redis bucket 工厂
82+
83+ :param rates: pyrate_limiter Rate 对象列表
84+ :param bucket_key: Redis key 前缀
85+ :param max_cache_size: 本地 bucket 缓存最大数量
86+ :return:
87+ """
88+ self .rates = rates
89+ self .bucket_key = f'{ bucket_key } :{ self ._rate_key (rates )} '
90+ self .max_cache_size = max (1 , max_cache_size )
91+ self .cache_ttl = max (rate .interval for rate in rates ) + REQUEST_LIMITER_BUCKET_CACHE_BUFFER_MS
92+ self .lock = Lock ()
93+ self .buckets : OrderedDict [str , RedisBucketState ] = OrderedDict ()
94+
95+ async def wrap_item (self , name : str , weight : int = 1 ) -> RateItem :
96+ """
97+ 包装限流项
98+
99+ :param name: 限流标识符
100+ :param weight: 请求权重
101+ :return:
102+ """
103+ return RateItem (name , await _redis_time_ms (redis_client ), weight = weight )
104+
105+ async def get (self , item : RateItem ) -> RedisBucket :
106+ """
107+ 获取标识符对应的 Redis bucket
108+
109+ :param item: 限流项
110+ :return:
111+ """
112+ bucket_key = self ._bucket_key (item .name )
113+ now = await _redis_time_ms (redis_client )
114+
115+ async with self .lock :
116+ state = self .buckets .get (bucket_key )
117+ if state is not None :
118+ state .last_seen = now
119+ self .buckets .move_to_end (bucket_key )
120+ return state .bucket
121+
122+ bucket_result = RedisTimeBucket .init (
123+ rates = self .rates ,
124+ redis = redis_client ,
125+ bucket_key = bucket_key ,
126+ )
127+ bucket = await _maybe_await (bucket_result )
128+ self .buckets [bucket_key ] = RedisBucketState (bucket = bucket , last_seen = now )
129+ self .schedule_leak (bucket )
130+ await self ._evict (now )
131+ return bucket
132+
133+ async def get_bucket (self , name : str ) -> RedisBucket :
134+ """
135+ 获取标识符对应的 Redis bucket
136+
137+ :param name: 限流标识符
138+ :return:
139+ """
140+ return await self .get (await self .wrap_item (name ))
141+
142+ async def _evict (self , now : int ) -> None :
143+ """
144+ 淘汰本地 bucket 缓存
145+
146+ :param now: 当前时间戳,单位毫秒
147+ :return:
148+ """
149+ for bucket_key , state in list (self .buckets .items ()):
150+ if now - state .last_seen <= self .cache_ttl :
151+ continue
152+ await self ._dispose (bucket_key , state , now , cleanup = True )
153+
154+ while len (self .buckets ) > self .max_cache_size :
155+ bucket_key , state = next (iter (self .buckets .items ()))
156+ await self ._dispose (bucket_key , state , now , cleanup = False )
157+
158+ async def _dispose (self , bucket_key : str , state : RedisBucketState , now : int , * , cleanup : bool ) -> None :
159+ """
160+ 移除本地 bucket 并按需清理 Redis 过期数据
161+
162+ :param bucket_key: Redis bucket key
163+ :param state: Redis bucket 缓存状态
164+ :param now: 当前时间戳,单位毫秒
165+ :param cleanup: 是否执行 Redis 过期数据清理
166+ :return:
167+ """
168+ self .buckets .pop (bucket_key , None )
169+ self .dispose (state .bucket )
170+ if cleanup :
171+ await _maybe_await (state .bucket .leak (now ))
172+ if await _maybe_await (state .bucket .count ()) == 0 :
173+ await _maybe_await (state .bucket .flush ())
174+
175+ def _bucket_key (self , name : str ) -> str :
176+ """
177+ 生成标识符对应的 Redis bucket key
178+
179+ :param name: 限流标识符
180+ :return:
181+ """
182+ digest = sha256 (name .encode ()).hexdigest ()
183+ return f'{ self .bucket_key } :{ digest } '
184+
185+ @staticmethod
186+ def _rate_key (rates : list [Rate ]) -> str :
187+ """
188+ 生成限流策略对应的 Redis key 片段
189+
190+ :param rates: pyrate_limiter Rate 对象列表
191+ :return:
192+ """
193+ value = ':' .join (f'{ rate .limit } :{ rate .interval } ' for rate in sorted (rates , key = lambda rate : rate .interval ))
194+ return sha256 (value .encode ()).hexdigest ()
21195
22196
23197def default_identifier (request : Request ) -> str :
@@ -68,23 +242,32 @@ def __init__(
68242 :param callback: 自定义限流回调函数
69243 :return:
70244 """
71- if not rates and bucket is None :
72- raise errors .ServerError (msg = '至少需要传入一个 Rate 或 bucket 实例' )
245+ if limiter is None and not rates and bucket is None :
246+ raise errors .ServerError (msg = '至少需要传入一个 Rate、bucket 或 limiter 实例' )
73247 self .rates = list (rates )
74248 self .identifier = identifier
75249 self .bucket = bucket
76250 self .limiter = limiter
77251 self .callback = callback
252+ self .bucket_factory : RedisBucketFactory | None = None
78253
79254 async def __call__ (self , request : Request , response : Response ) -> None :
255+ """
256+ 执行请求限流检查
257+
258+ :param request: FastAPI 请求对象
259+ :param response: FastAPI 响应对象
260+ :return:
261+ """
80262 if self .limiter is None :
81263 if self .bucket is None :
82- self .bucket = await RedisBucket . init ( # type: ignore
264+ self .bucket_factory = RedisBucketFactory (
83265 rates = self .rates ,
84- redis = redis_client ,
85266 bucket_key = f'{ settings .REQUEST_LIMITER_REDIS_PREFIX } ' ,
86267 )
87- self .limiter = Limiter (self .bucket )
268+ self .limiter = Limiter (self .bucket_factory )
269+ else :
270+ self .limiter = Limiter (self .bucket )
88271
89272 if is_async_callable (self .identifier ):
90273 identifier = await self .identifier (request )
@@ -93,8 +276,35 @@ async def __call__(self, request: Request, response: Response) -> None:
93276
94277 acquired = await self .limiter .try_acquire_async (identifier , blocking = False )
95278 if not acquired :
96- retry_after = ceil ( self .bucket . failing_rate . interval / 1000 )
279+ retry_after = await self ._retry_after ( identifier )
97280 if is_async_callable (self .callback ):
98281 await self .callback (request , response , retry_after )
99282 else :
100283 await run_in_threadpool (self .callback , request , response , retry_after )
284+
285+ async def _retry_after (self , identifier : str ) -> int :
286+ """
287+ 计算限流重试等待时间
288+
289+ :param identifier: 限流标识符
290+ :return:
291+ """
292+ if self .bucket_factory is not None :
293+ failing_rate = (await self .bucket_factory .get_bucket (identifier )).failing_rate
294+ elif self .bucket is not None :
295+ failing_rate = self .bucket .failing_rate
296+ else :
297+ failing_rate = None
298+
299+ if failing_rate is not None :
300+ return ceil (failing_rate .interval / 1000 )
301+
302+ if self .limiter is not None :
303+ for bucket in self .limiter .buckets ():
304+ if bucket .failing_rate is not None :
305+ return ceil (bucket .failing_rate .interval / 1000 )
306+
307+ if self .rates :
308+ return ceil (max (rate .interval for rate in self .rates ) / 1000 )
309+
310+ return 1
0 commit comments