Skip to content

Commit 11667dd

Browse files
committed
fix: normalize dashboard client IP from trusted proxy headers
1 parent 58da370 commit 11667dd

1 file changed

Lines changed: 73 additions & 12 deletions

File tree

astrbot/dashboard/server.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import hashlib
3+
import ipaddress
34
import logging
45
import os
56
import socket
@@ -13,6 +14,8 @@
1314
from flask.json.provider import DefaultJSONProvider
1415
from hypercorn.asyncio import serve
1516
from hypercorn.config import Config as HyperConfig
17+
from hypercorn.logging import AccessLogAtoms
18+
from hypercorn.logging import Logger as HypercornLogger
1619
from quart import Quart, g, jsonify, request
1720
from quart.logging import default_handler
1821
from werkzeug.exceptions import MethodNotAllowed, NotFound
@@ -125,6 +128,53 @@ def _parse_env_bool(value: str | None, default: bool) -> bool:
125128
return value.strip().lower() in {"1", "true", "yes", "on"}
126129

127130

131+
class _ProxyAwareHypercornLogger(HypercornLogger):
132+
@staticmethod
133+
def _get_request_log_host(request_scope) -> str | None:
134+
forwarded_for = None
135+
real_ip = None
136+
for raw_name, raw_value in request_scope.get("headers", []):
137+
header_name = raw_name.decode("latin1").lower()
138+
if header_name == "x-forwarded-for":
139+
forwarded_for = raw_value.decode("latin1")
140+
elif header_name == "x-real-ip":
141+
real_ip = raw_value.decode("latin1")
142+
143+
if forwarded_for is not None and real_ip is not None:
144+
break
145+
146+
forwarded_for = str(forwarded_for or "").strip()
147+
if forwarded_for:
148+
first_ip = forwarded_for.split(",", 1)[0].strip()
149+
if first_ip and first_ip.lower() != "unknown":
150+
try:
151+
return str(ipaddress.ip_address(first_ip))
152+
except ValueError:
153+
pass
154+
155+
real_ip = str(real_ip or "").strip()
156+
if real_ip and real_ip.lower() != "unknown":
157+
try:
158+
return str(ipaddress.ip_address(real_ip))
159+
except ValueError:
160+
pass
161+
162+
client = request_scope.get("client")
163+
if not client:
164+
return None
165+
host = str(client[0]).strip()
166+
if host:
167+
return host
168+
return None
169+
170+
def atoms(self, request, response, request_time):
171+
atoms = AccessLogAtoms(request, response, request_time)
172+
client_host = self._get_request_log_host(request)
173+
if client_host:
174+
atoms["h"] = client_host
175+
return atoms
176+
177+
128178
class AstrBotJSONProvider(DefaultJSONProvider):
129179
def default(self, obj):
130180
if isinstance(obj, datetime):
@@ -293,7 +343,7 @@ async def auth_middleware(self):
293343
if max_burst <= 0:
294344
max_burst = 3
295345
refill_rate = 1.0 / average_interval
296-
client_ip = self._get_request_client_ip()
346+
client_ip = self._get_request_client_ip(request)
297347
limiter = _rate_limiters.get(client_ip)
298348
if limiter is None:
299349
limiter = _AuthRateLimiter(
@@ -358,24 +408,33 @@ async def auth_middleware(self):
358408
r.status_code = 401
359409
return r
360410

361-
def _get_request_client_ip(self) -> str:
362-
trust_proxy_headers = bool(
363-
self.config.get("dashboard", {}).get("trust_proxy_headers", False)
364-
)
365-
if trust_proxy_headers:
366-
forwarded_for = request.headers.get("X-Forwarded-For", "").strip()
411+
def _get_request_client_ip(self, current_request) -> str:
412+
if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)):
413+
forwarded_for = str(
414+
current_request.headers.get("X-Forwarded-For", "")
415+
).strip()
367416
if forwarded_for:
368417
first_ip = forwarded_for.split(",", 1)[0].strip()
369418
if first_ip and first_ip.lower() != "unknown":
370-
return first_ip
419+
try:
420+
return str(ipaddress.ip_address(first_ip))
421+
except ValueError:
422+
pass
371423

372-
real_ip = request.headers.get("X-Real-IP", "").strip()
424+
real_ip = str(current_request.headers.get("X-Real-IP", "")).strip()
373425
if real_ip and real_ip.lower() != "unknown":
374-
return real_ip
426+
try:
427+
return str(ipaddress.ip_address(real_ip))
428+
except ValueError:
429+
pass
375430

376-
remote_addr = request.remote_addr
431+
remote_addr = str(current_request.remote_addr or "").strip()
377432
if remote_addr:
378-
return str(remote_addr)
433+
try:
434+
return str(ipaddress.ip_address(remote_addr))
435+
except ValueError:
436+
pass
437+
379438
return "unknown"
380439

381440
@staticmethod
@@ -613,6 +672,8 @@ def run(self):
613672
# 配置 Hypercorn
614673
config = HyperConfig()
615674
config.bind = [f"{host}:{port}"]
675+
if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)):
676+
config.logger_class = _ProxyAwareHypercornLogger
616677
if ssl_enable:
617678
config.certfile = resolved_ssl_config["certfile"]
618679
config.keyfile = resolved_ssl_config["keyfile"]

0 commit comments

Comments
 (0)