Skip to content

Commit 5edcbfe

Browse files
committed
refactor: encapsulate rate limiter state into registry, add TTL
1 parent 11667dd commit 5edcbfe

2 files changed

Lines changed: 109 additions & 60 deletions

File tree

astrbot/dashboard/server.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(self, capacity: int, refill_rate: float):
6161
self.refill_rate = refill_rate
6262
self.tokens = float(capacity)
6363
self.last_refill = time.monotonic()
64+
self.last_accessed = time.monotonic()
6465
self.lock = asyncio.Lock()
6566

6667
async def acquire(self) -> bool:
@@ -69,13 +70,51 @@ async def acquire(self) -> bool:
6970
elapsed = now - self.last_refill
7071
self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
7172
self.last_refill = now
73+
self.last_accessed = now
7274
if self.tokens >= 1:
7375
self.tokens -= 1
7476
return True
7577
return False
7678

7779

78-
_rate_limiters: dict[str, _AuthRateLimiter] = {}
80+
class _RateLimiterRegistry:
81+
"""Per-IP token-bucket rate limiter registry. Idle entries expire after 1 hour."""
82+
83+
_ENTRY_TTL: float = 3600.0
84+
_INTERVAL: float = 1800.0
85+
86+
def __init__(self) -> None:
87+
self._limiters: dict[str, _AuthRateLimiter] = {}
88+
self._last_eviction = time.monotonic()
89+
90+
def get_or_create(
91+
self, key: str, capacity: int, refill_rate: float
92+
) -> _AuthRateLimiter:
93+
self._evict_expired()
94+
limiter = self._limiters.get(key)
95+
if limiter is None:
96+
limiter = _AuthRateLimiter(capacity=capacity, refill_rate=refill_rate)
97+
self._limiters[key] = limiter
98+
return limiter
99+
100+
def _evict_expired(self) -> None:
101+
now = time.monotonic()
102+
if now - self._last_eviction < self._INTERVAL:
103+
return
104+
self._last_eviction = now
105+
cutoff = now - self._ENTRY_TTL
106+
stale = [k for k, v in self._limiters.items() if v.last_accessed < cutoff]
107+
for k in stale:
108+
del self._limiters[k]
109+
110+
def clear(self) -> None:
111+
self._limiters.clear()
112+
113+
def __len__(self) -> int:
114+
return len(self._limiters)
115+
116+
def __contains__(self, key: str) -> bool:
117+
return key in self._limiters
79118

80119

81120
class _AddrWithPort(Protocol):
@@ -215,6 +254,7 @@ def __init__(
215254
# Fall back to expected user path (will fail gracefully later)
216255
self.data_path = os.path.abspath(user_dist)
217256

257+
self._rate_limiter_registry = _RateLimiterRegistry()
218258
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
219259
APP = self.app # noqa
220260
self.app.config["MAX_CONTENT_LENGTH"] = (
@@ -344,12 +384,9 @@ async def auth_middleware(self):
344384
max_burst = 3
345385
refill_rate = 1.0 / average_interval
346386
client_ip = self._get_request_client_ip(request)
347-
limiter = _rate_limiters.get(client_ip)
348-
if limiter is None:
349-
limiter = _AuthRateLimiter(
350-
capacity=max_burst, refill_rate=refill_rate
351-
)
352-
_rate_limiters[client_ip] = limiter
387+
limiter = self._rate_limiter_registry.get_or_create(
388+
client_ip, capacity=max_burst, refill_rate=refill_rate
389+
)
353390
if not await limiter.acquire():
354391
r = jsonify(
355392
Response()

tests/test_dashboard.py

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from quart import Quart, jsonify
1919
from werkzeug.datastructures import FileStorage
2020

21-
import astrbot.dashboard.server as dashboard_server
2221
from astrbot.core import LogBroker
2322
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
2423
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -208,6 +207,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle):
208207
shutdown_event = asyncio.Event()
209208
# The db instance is already part of the core_lifecycle_td
210209
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
210+
server.app._dashboard_server = server # expose for test cleanup
211211
return server.app
212212

213213

@@ -363,34 +363,35 @@ async def test_auth_login_secure_cookie_override(
363363

364364

365365
@pytest.mark.asyncio
366-
async def test_auth_rate_limit_uses_client_ip_bucket_across_paths(
366+
async def test_auth_rate_limit_uses_same_bucket_across_paths(
367367
app: Quart,
368368
core_lifecycle_td: AstrBotCoreLifecycle,
369369
monkeypatch: pytest.MonkeyPatch,
370370
):
371+
"""Same client IP shares a rate-limit bucket across different auth endpoints."""
371372
monkeypatch.setenv("ASTRBOT_TEST_MODE", "false")
372-
dashboard_server._rate_limiters.clear()
373-
original_value = core_lifecycle_td.astrbot_config["dashboard"].get(
374-
"trust_proxy_headers", False
375-
)
376-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = True
373+
app._dashboard_server._rate_limiter_registry.clear()
374+
cfg = core_lifecycle_td.astrbot_config["dashboard"]
375+
rl_original = cfg.get("auth_rate_limit", {})
376+
tp_original = cfg.get("trust_proxy_headers", False)
377+
cfg["auth_rate_limit"] = {"enable": True, "average_interval": 3600.0, "max_burst": 1}
378+
cfg["trust_proxy_headers"] = True
377379

378380
try:
379-
test_client = app.test_client()
380-
headers = {"X-Forwarded-For": "198.51.100.10"}
381-
await test_client.post(
382-
"/api/auth/login",
383-
json={"username": "wrong", "password": "wrong"},
384-
headers=headers,
381+
client = app.test_client()
382+
h = {"X-Forwarded-For": "198.51.100.10"}
383+
r1 = await client.post(
384+
"/api/auth/login", json={"username": "u", "password": "p"}, headers=h
385385
)
386-
await test_client.post("/api/auth/totp/setup", json={}, headers=headers)
386+
assert r1.status_code != 429, "first request from IP should not be rate limited"
387387

388-
assert len(dashboard_server._rate_limiters) == 1
389-
assert "198.51.100.10" in dashboard_server._rate_limiters
390-
finally:
391-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = (
392-
original_value
388+
r2 = await client.post("/api/auth/totp/setup", json={}, headers=h)
389+
assert r2.status_code == 429, (
390+
"second request from same IP should be rate limited"
393391
)
392+
finally:
393+
cfg["auth_rate_limit"] = rl_original
394+
cfg["trust_proxy_headers"] = tp_original
394395

395396

396397
@pytest.mark.asyncio
@@ -399,33 +400,42 @@ async def test_auth_rate_limit_separates_different_client_ips(
399400
core_lifecycle_td: AstrBotCoreLifecycle,
400401
monkeypatch: pytest.MonkeyPatch,
401402
):
403+
"""Different client IPs have independent rate-limit buckets."""
402404
monkeypatch.setenv("ASTRBOT_TEST_MODE", "false")
403-
dashboard_server._rate_limiters.clear()
404-
original_value = core_lifecycle_td.astrbot_config["dashboard"].get(
405-
"trust_proxy_headers", False
406-
)
407-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = True
405+
app._dashboard_server._rate_limiter_registry.clear()
406+
cfg = core_lifecycle_td.astrbot_config["dashboard"]
407+
rl_original = cfg.get("auth_rate_limit", {})
408+
tp_original = cfg.get("trust_proxy_headers", False)
409+
cfg["auth_rate_limit"] = {"enable": True, "average_interval": 3600.0, "max_burst": 1}
410+
cfg["trust_proxy_headers"] = True
408411

409412
try:
410-
test_client = app.test_client()
411-
await test_client.post(
413+
client = app.test_client()
414+
r_a = await client.post(
412415
"/api/auth/login",
413-
json={"username": "wrong", "password": "wrong"},
416+
json={"username": "u", "password": "p"},
414417
headers={"X-Forwarded-For": "198.51.100.10"},
415418
)
416-
await test_client.post(
419+
assert r_a.status_code != 429
420+
421+
r_b = await client.post(
417422
"/api/auth/login",
418-
json={"username": "wrong", "password": "wrong"},
419-
headers={"X-Forwarded-For": "198.51.100.11"},
423+
json={"username": "u", "password": "p"},
424+
headers={"X-Forwarded-For": "198.51.100.10"},
425+
)
426+
assert r_b.status_code == 429, (
427+
"second request from same IP should be rate limited"
420428
)
421429

422-
assert len(dashboard_server._rate_limiters) == 2
423-
assert "198.51.100.10" in dashboard_server._rate_limiters
424-
assert "198.51.100.11" in dashboard_server._rate_limiters
425-
finally:
426-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = (
427-
original_value
430+
r_c = await client.post(
431+
"/api/auth/login",
432+
json={"username": "u", "password": "p"},
433+
headers={"X-Forwarded-For": "198.51.100.11"},
428434
)
435+
assert r_c.status_code != 429, "different IP has its own bucket"
436+
finally:
437+
cfg["auth_rate_limit"] = rl_original
438+
cfg["trust_proxy_headers"] = tp_original
429439

430440

431441
@pytest.mark.asyncio
@@ -434,33 +444,35 @@ async def test_auth_rate_limit_ignores_proxy_headers_by_default(
434444
core_lifecycle_td: AstrBotCoreLifecycle,
435445
monkeypatch: pytest.MonkeyPatch,
436446
):
447+
"""When trust_proxy_headers is False, all proxy-spoofed IPs fall back to the connection IP."""
437448
monkeypatch.setenv("ASTRBOT_TEST_MODE", "false")
438-
dashboard_server._rate_limiters.clear()
439-
original_value = core_lifecycle_td.astrbot_config["dashboard"].get(
440-
"trust_proxy_headers", False
441-
)
442-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = False
449+
app._dashboard_server._rate_limiter_registry.clear()
450+
cfg = core_lifecycle_td.astrbot_config["dashboard"]
451+
rl_original = cfg.get("auth_rate_limit", {})
452+
tp_original = cfg.get("trust_proxy_headers", False)
453+
cfg["auth_rate_limit"] = {"enable": True, "average_interval": 3600.0, "max_burst": 1}
454+
cfg["trust_proxy_headers"] = False
443455

444456
try:
445-
test_client = app.test_client()
446-
await test_client.post(
457+
client = app.test_client()
458+
r1 = await client.post(
447459
"/api/auth/login",
448-
json={"username": "wrong", "password": "wrong"},
460+
json={"username": "u", "password": "p"},
449461
headers={"X-Forwarded-For": "198.51.100.20"},
450462
)
451-
await test_client.post(
463+
assert r1.status_code != 429
464+
465+
r2 = await client.post(
452466
"/api/auth/login",
453-
json={"username": "wrong", "password": "wrong"},
467+
json={"username": "u", "password": "p"},
454468
headers={"X-Forwarded-For": "198.51.100.21"},
455469
)
456-
457-
assert len(dashboard_server._rate_limiters) == 1
458-
assert "198.51.100.20" not in dashboard_server._rate_limiters
459-
assert "198.51.100.21" not in dashboard_server._rate_limiters
460-
finally:
461-
core_lifecycle_td.astrbot_config["dashboard"]["trust_proxy_headers"] = (
462-
original_value
470+
assert r2.status_code == 429, (
471+
"same connection IP, same bucket despite proxy headers"
463472
)
473+
finally:
474+
cfg["auth_rate_limit"] = rl_original
475+
cfg["trust_proxy_headers"] = tp_original
464476

465477

466478
@pytest.mark.asyncio

0 commit comments

Comments
 (0)