|
1 | 1 | import asyncio |
2 | 2 | import hashlib |
| 3 | +import ipaddress |
3 | 4 | import logging |
4 | 5 | import os |
5 | 6 | import socket |
|
13 | 14 | from flask.json.provider import DefaultJSONProvider |
14 | 15 | from hypercorn.asyncio import serve |
15 | 16 | from hypercorn.config import Config as HyperConfig |
| 17 | +from hypercorn.logging import AccessLogAtoms |
| 18 | +from hypercorn.logging import Logger as HypercornLogger |
16 | 19 | from quart import Quart, g, jsonify, request |
17 | 20 | from quart.logging import default_handler |
18 | 21 | from werkzeug.exceptions import MethodNotAllowed, NotFound |
@@ -125,6 +128,53 @@ def _parse_env_bool(value: str | None, default: bool) -> bool: |
125 | 128 | return value.strip().lower() in {"1", "true", "yes", "on"} |
126 | 129 |
|
127 | 130 |
|
| 131 | +class _ProxyAwareHypercornLogger(HypercornLogger): |
| 132 | + @staticmethod |
| 133 | + def _get_request_log_host(request_scope) -> str | None: |
| 134 | + forwarded_for = None |
| 135 | + real_ip = None |
| 136 | + for raw_name, raw_value in request_scope.get("headers", []): |
| 137 | + header_name = raw_name.decode("latin1").lower() |
| 138 | + if header_name == "x-forwarded-for": |
| 139 | + forwarded_for = raw_value.decode("latin1") |
| 140 | + elif header_name == "x-real-ip": |
| 141 | + real_ip = raw_value.decode("latin1") |
| 142 | + |
| 143 | + if forwarded_for is not None and real_ip is not None: |
| 144 | + break |
| 145 | + |
| 146 | + forwarded_for = str(forwarded_for or "").strip() |
| 147 | + if forwarded_for: |
| 148 | + first_ip = forwarded_for.split(",", 1)[0].strip() |
| 149 | + if first_ip and first_ip.lower() != "unknown": |
| 150 | + try: |
| 151 | + return str(ipaddress.ip_address(first_ip)) |
| 152 | + except ValueError: |
| 153 | + pass |
| 154 | + |
| 155 | + real_ip = str(real_ip or "").strip() |
| 156 | + if real_ip and real_ip.lower() != "unknown": |
| 157 | + try: |
| 158 | + return str(ipaddress.ip_address(real_ip)) |
| 159 | + except ValueError: |
| 160 | + pass |
| 161 | + |
| 162 | + client = request_scope.get("client") |
| 163 | + if not client: |
| 164 | + return None |
| 165 | + host = str(client[0]).strip() |
| 166 | + if host: |
| 167 | + return host |
| 168 | + return None |
| 169 | + |
| 170 | + def atoms(self, request, response, request_time): |
| 171 | + atoms = AccessLogAtoms(request, response, request_time) |
| 172 | + client_host = self._get_request_log_host(request) |
| 173 | + if client_host: |
| 174 | + atoms["h"] = client_host |
| 175 | + return atoms |
| 176 | + |
| 177 | + |
128 | 178 | class AstrBotJSONProvider(DefaultJSONProvider): |
129 | 179 | def default(self, obj): |
130 | 180 | if isinstance(obj, datetime): |
@@ -293,7 +343,7 @@ async def auth_middleware(self): |
293 | 343 | if max_burst <= 0: |
294 | 344 | max_burst = 3 |
295 | 345 | refill_rate = 1.0 / average_interval |
296 | | - client_ip = self._get_request_client_ip() |
| 346 | + client_ip = self._get_request_client_ip(request) |
297 | 347 | limiter = _rate_limiters.get(client_ip) |
298 | 348 | if limiter is None: |
299 | 349 | limiter = _AuthRateLimiter( |
@@ -358,24 +408,33 @@ async def auth_middleware(self): |
358 | 408 | r.status_code = 401 |
359 | 409 | return r |
360 | 410 |
|
361 | | - def _get_request_client_ip(self) -> str: |
362 | | - trust_proxy_headers = bool( |
363 | | - self.config.get("dashboard", {}).get("trust_proxy_headers", False) |
364 | | - ) |
365 | | - if trust_proxy_headers: |
366 | | - forwarded_for = request.headers.get("X-Forwarded-For", "").strip() |
| 411 | + def _get_request_client_ip(self, current_request) -> str: |
| 412 | + if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): |
| 413 | + forwarded_for = str( |
| 414 | + current_request.headers.get("X-Forwarded-For", "") |
| 415 | + ).strip() |
367 | 416 | if forwarded_for: |
368 | 417 | first_ip = forwarded_for.split(",", 1)[0].strip() |
369 | 418 | if first_ip and first_ip.lower() != "unknown": |
370 | | - return first_ip |
| 419 | + try: |
| 420 | + return str(ipaddress.ip_address(first_ip)) |
| 421 | + except ValueError: |
| 422 | + pass |
371 | 423 |
|
372 | | - real_ip = request.headers.get("X-Real-IP", "").strip() |
| 424 | + real_ip = str(current_request.headers.get("X-Real-IP", "")).strip() |
373 | 425 | if real_ip and real_ip.lower() != "unknown": |
374 | | - return real_ip |
| 426 | + try: |
| 427 | + return str(ipaddress.ip_address(real_ip)) |
| 428 | + except ValueError: |
| 429 | + pass |
375 | 430 |
|
376 | | - remote_addr = request.remote_addr |
| 431 | + remote_addr = str(current_request.remote_addr or "").strip() |
377 | 432 | if remote_addr: |
378 | | - return str(remote_addr) |
| 433 | + try: |
| 434 | + return str(ipaddress.ip_address(remote_addr)) |
| 435 | + except ValueError: |
| 436 | + pass |
| 437 | + |
379 | 438 | return "unknown" |
380 | 439 |
|
381 | 440 | @staticmethod |
@@ -613,6 +672,8 @@ def run(self): |
613 | 672 | # 配置 Hypercorn |
614 | 673 | config = HyperConfig() |
615 | 674 | config.bind = [f"{host}:{port}"] |
| 675 | + if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): |
| 676 | + config.logger_class = _ProxyAwareHypercornLogger |
616 | 677 | if ssl_enable: |
617 | 678 | config.certfile = resolved_ssl_config["certfile"] |
618 | 679 | config.keyfile = resolved_ssl_config["keyfile"] |
|
0 commit comments