Skip to content

Commit 0441907

Browse files
authored
Fix redis rate limiter bucket routing (#1212)
* Fix redis rate limiter bucket routing * Add rate limiter function docstrings * Optimize Redis rate limiter time source
1 parent 08ee387 commit 0441907

1 file changed

Lines changed: 218 additions & 8 deletions

File tree

backend/utils/limiter.py

Lines changed: 218 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from asyncio import Lock
2+
from collections import OrderedDict
13
from collections.abc import Awaitable, Callable
4+
from dataclasses import dataclass
5+
from hashlib import sha256
6+
from inspect import isawaitable
27
from math import ceil
3-
from typing import TypeAlias
8+
from typing import TypeAlias, TypeVar
49

510
from fastapi import Request, Response
611
from fastapi_pagination.utils import is_async_callable
7-
from pyrate_limiter import AbstractBucket, Limiter, Rate
12+
from pyrate_limiter import AbstractBucket, BucketFactory, Limiter, Rate, RateItem
813
from pyrate_limiter.buckets import RedisBucket
14+
from redis.asyncio import Redis
915
from starlette.concurrency import run_in_threadpool
1016

1117
from backend.common.exception import errors
@@ -18,6 +24,174 @@
1824
# Signature accepted for a custom rate-limit callback: it receives the request,
# the response and an int (the limiter passes the value computed by
# `_retry_after`), and may be either a plain function or a coroutine function.
CallbackCallable: TypeAlias = (
1925
Callable[[Request, Response, int], None] | Callable[[Request, Response, int], Awaitable[None]]
2026
)
27+
# Generic value type used in the signature of `_maybe_await`.
T = TypeVar('T')
28+
29+
# Default for RedisBucketFactory's `max_cache_size` (number of locally cached buckets).
REQUEST_LIMITER_BUCKET_CACHE_MAX_SIZE = 4096
30+
# Added to the longest rate interval to obtain RedisBucketFactory's `cache_ttl`;
# that TTL is compared against millisecond timestamps.
REQUEST_LIMITER_BUCKET_CACHE_BUFFER_MS = 10_000
31+
32+
33+
@dataclass(slots=True)
34+
class RedisBucketState:
35+
"""Redis bucket 缓存状态"""
36+
37+
bucket: RedisBucket
38+
last_seen: int
39+
40+
41+
async def _maybe_await(value: T | Awaitable[T]) -> T:
    """
    Resolve a value that may or may not need awaiting

    :param value: a plain value, or an awaitable producing one
    :return: the awaited result when ``value`` is awaitable, otherwise ``value`` itself
    """
    return (await value) if isawaitable(value) else value
51+
52+
53+
async def _redis_time_ms(redis: Redis) -> int:
    """
    Read the current time reported by the Redis server, in milliseconds

    :param redis: client whose awaited ``time()`` call yields a ``(seconds, microseconds)`` pair
    :return: whole milliseconds; the sub-millisecond part of the microseconds is truncated
    """
    secs, micros = await redis.time()
    whole_ms = micros // 1000
    return secs * 1000 + whole_ms
61+
62+
63+
class RedisTimeBucket(RedisBucket):
    """Redis bucket that takes its current time from the Redis server"""

    async def now(self) -> int:
        """
        Report the current time for this bucket

        :return: Redis server time in milliseconds, read through this bucket's ``redis`` client
        """
        server_ms = await _redis_time_ms(self.redis)
        return server_ms
69+
70+
71+
class RedisBucketFactory(BucketFactory):
    """Route each request identifier to its own Redis bucket"""

    def __init__(
        self,
        rates: list[Rate],
        bucket_key: str,
        max_cache_size: int = REQUEST_LIMITER_BUCKET_CACHE_MAX_SIZE,
    ) -> None:
        """
        Set up the factory with an empty local bucket cache

        :param rates: pyrate_limiter Rate objects applied to every bucket
        :param bucket_key: Redis key prefix; a digest of ``rates`` is appended to it
        :param max_cache_size: upper bound on locally cached buckets, raised to 1 when smaller
        :return:
        """
        self.rates = rates
        self.bucket_key = f'{bucket_key}:{self._rate_key(rates)}'
        self.max_cache_size = max(1, max_cache_size)
        # An idle entry is kept for the longest rate interval plus a fixed buffer
        longest_interval = max(rate.interval for rate in rates)
        self.cache_ttl = longest_interval + REQUEST_LIMITER_BUCKET_CACHE_BUFFER_MS
        self.lock = Lock()
        self.buckets: OrderedDict[str, RedisBucketState] = OrderedDict()

    async def wrap_item(self, name: str, weight: int = 1) -> RateItem:
        """
        Build a rate item stamped with the Redis server time

        :param name: rate-limit identifier
        :param weight: request weight
        :return: RateItem carrying ``name``, the server timestamp in milliseconds and ``weight``
        """
        timestamp = await _redis_time_ms(redis_client)
        return RateItem(name, timestamp, weight=weight)

    async def get(self, item: RateItem) -> RedisBucket:
        """
        Return the Redis bucket for the item's identifier, creating it on first use

        :param item: rate item whose ``name`` selects the bucket
        :return: the cached bucket, or a newly initialised one
        """
        key = self._bucket_key(item.name)
        # The server clock is read before the lock is taken
        current_ms = await _redis_time_ms(redis_client)

        async with self.lock:
            cached = self.buckets.get(key)
            if cached is None:
                # Cache miss: create, register, schedule leaking, then trim the cache
                created = await _maybe_await(
                    RedisTimeBucket.init(
                        rates=self.rates,
                        redis=redis_client,
                        bucket_key=key,
                    )
                )
                self.buckets[key] = RedisBucketState(bucket=created, last_seen=current_ms)
                self.schedule_leak(created)
                await self._evict(current_ms)
                return created

            # Cache hit: refresh recency so this entry is evicted last
            cached.last_seen = current_ms
            self.buckets.move_to_end(key)
            return cached.bucket

    async def get_bucket(self, name: str) -> RedisBucket:
        """
        Return the Redis bucket for an identifier given by name

        :param name: rate-limit identifier
        :return: same bucket ``get`` returns for an item wrapped from ``name``
        """
        item = await self.wrap_item(name)
        return await self.get(item)

    async def _evict(self, now: int) -> None:
        """
        Trim the local bucket cache

        :param now: current timestamp in milliseconds
        :return:
        """
        # First pass: entries idle longer than the TTL are removed with Redis cleanup
        for key, entry in list(self.buckets.items()):
            if now - entry.last_seen > self.cache_ttl:
                await self._dispose(key, entry, now, cleanup=True)

        # Second pass: while over capacity, drop from the front of the ordered cache without Redis cleanup
        while len(self.buckets) > self.max_cache_size:
            oldest_key = next(iter(self.buckets))
            await self._dispose(oldest_key, self.buckets[oldest_key], now, cleanup=False)

    async def _dispose(self, bucket_key: str, state: RedisBucketState, now: int, *, cleanup: bool) -> None:
        """
        Remove a bucket from the local cache and optionally clean its Redis data

        :param bucket_key: Redis bucket key
        :param state: cache entry being removed
        :param now: current timestamp in milliseconds
        :param cleanup: whether to leak the bucket and flush it when nothing remains
        :return:
        """
        self.buckets.pop(bucket_key, None)
        self.dispose(state.bucket)
        if not cleanup:
            return

        target = state.bucket
        await _maybe_await(target.leak(now))
        remaining = await _maybe_await(target.count())
        # Flush only when the leak left the bucket empty
        if remaining == 0:
            await _maybe_await(target.flush())

    def _bucket_key(self, name: str) -> str:
        """
        Build the Redis bucket key for an identifier

        :param name: rate-limit identifier
        :return: factory key prefix followed by the SHA-256 hex digest of ``name``
        """
        name_digest = sha256(name.encode()).hexdigest()
        return f'{self.bucket_key}:{name_digest}'

    @staticmethod
    def _rate_key(rates: list[Rate]) -> str:
        """
        Build the Redis key fragment that identifies a rate policy

        :param rates: pyrate_limiter Rate objects
        :return: SHA-256 hex digest of the ``limit:interval`` pairs, ordered by interval
        """
        ordered = sorted(rates, key=lambda rate: rate.interval)
        fingerprint = ':'.join(f'{rate.limit}:{rate.interval}' for rate in ordered)
        return sha256(fingerprint.encode()).hexdigest()
21195

22196

23197
def default_identifier(request: Request) -> str:
@@ -68,23 +242,32 @@ def __init__(
68242
:param callback: 自定义限流回调函数
69243
:return:
70244
"""
71-
if not rates and bucket is None:
72-
raise errors.ServerError(msg='至少需要传入一个 Rate 或 bucket 实例')
245+
if limiter is None and not rates and bucket is None:
246+
raise errors.ServerError(msg='至少需要传入一个 Rate、bucketlimiter 实例')
73247
self.rates = list(rates)
74248
self.identifier = identifier
75249
self.bucket = bucket
76250
self.limiter = limiter
77251
self.callback = callback
252+
self.bucket_factory: RedisBucketFactory | None = None
78253

79254
async def __call__(self, request: Request, response: Response) -> None:
255+
"""
256+
执行请求限流检查
257+
258+
:param request: FastAPI 请求对象
259+
:param response: FastAPI 响应对象
260+
:return:
261+
"""
80262
if self.limiter is None:
81263
if self.bucket is None:
82-
self.bucket = await RedisBucket.init( # type: ignore
264+
self.bucket_factory = RedisBucketFactory(
83265
rates=self.rates,
84-
redis=redis_client,
85266
bucket_key=f'{settings.REQUEST_LIMITER_REDIS_PREFIX}',
86267
)
87-
self.limiter = Limiter(self.bucket)
268+
self.limiter = Limiter(self.bucket_factory)
269+
else:
270+
self.limiter = Limiter(self.bucket)
88271

89272
if is_async_callable(self.identifier):
90273
identifier = await self.identifier(request)
@@ -93,8 +276,35 @@ async def __call__(self, request: Request, response: Response) -> None:
93276

94277
acquired = await self.limiter.try_acquire_async(identifier, blocking=False)
95278
if not acquired:
96-
retry_after = ceil(self.bucket.failing_rate.interval / 1000)
279+
retry_after = await self._retry_after(identifier)
97280
if is_async_callable(self.callback):
98281
await self.callback(request, response, retry_after)
99282
else:
100283
await run_in_threadpool(self.callback, request, response, retry_after)
284+
285+
async def _retry_after(self, identifier: str) -> int:
    """
    Work out how long a rejected caller should wait before retrying

    Sources are tried in order: the bucket routed for ``identifier`` (or the
    explicitly supplied bucket), any bucket known to the limiter, the
    configured rates, and finally a fixed value of 1.

    :param identifier: rate-limit identifier
    :return: wait time, taken as the rate interval divided by 1000 and rounded up
    """
    rate = None
    if self.bucket_factory is not None:
        routed = await self.bucket_factory.get_bucket(identifier)
        rate = routed.failing_rate
    elif self.bucket is not None:
        rate = self.bucket.failing_rate

    # Fall back to the first limiter bucket that recorded a failing rate
    if rate is None and self.limiter is not None:
        rate = next(
            (
                candidate.failing_rate
                for candidate in self.limiter.buckets()
                if candidate.failing_rate is not None
            ),
            None,
        )

    if rate is not None:
        return ceil(rate.interval / 1000)

    if not self.rates:
        return 1

    return ceil(max(configured.interval for configured in self.rates) / 1000)

0 commit comments

Comments
 (0)