diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 8bf92ef4d5..5c8dd7d79b 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,7 +1,6 @@ -import httpx from openai import AsyncOpenAI -from astrbot import logger +from astrbot.core.utils.network_utils import create_proxy_client from ..entities import ProviderType from ..provider import EmbeddingProvider @@ -19,10 +18,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.provider_config = provider_config self.provider_settings = provider_settings proxy = provider_config.get("proxy", "") - http_client = None - if proxy: - logger.info(f"[OpenAI Embedding] 使用代理: {proxy}") - http_client = httpx.AsyncClient(proxy=proxy) + http_client = create_proxy_client("OpenAI Embedding", proxy) api_base = provider_config.get("embedding_api_base", "").strip() if not api_base: api_base = "https://api.openai.com/v1" diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 217b189251..6527f3eb73 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,11 +1,10 @@ import os import uuid -import httpx from openai import NOT_GIVEN, AsyncOpenAI -from astrbot import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.network_utils import create_proxy_client from ..entities import ProviderType from ..provider import TTSProvider @@ -32,10 +31,7 @@ def __init__( timeout = int(timeout) proxy = provider_config.get("proxy", "") - http_client = None - if proxy: - logger.info(f"[OpenAI TTS] 使用代理: {proxy}") - http_client = httpx.AsyncClient(proxy=proxy) + http_client = create_proxy_client("OpenAI TTS", proxy) self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=provider_config.get("api_base"), diff --git a/astrbot/core/utils/network_utils.py b/astrbot/core/utils/network_utils.py index 727f3762ae..a9654a2418 100644 --- a/astrbot/core/utils/network_utils.py +++ b/astrbot/core/utils/network_utils.py @@ -72,14 +72,66 @@ def log_connection_failure( ) if effective_proxy: + sanitized_proxy = _sanitize_proxy_url(effective_proxy) + error_text = str(error) + if effective_proxy: + error_text = error_text.replace(effective_proxy, sanitized_proxy) logger.error( f"[{provider_label}] 网络/代理连接失败 ({error_type})。" - f"代理地址: {effective_proxy},错误: {error}" + f"代理地址: {sanitized_proxy},错误: {error_text}" ) else: logger.error(f"[{provider_label}] 网络连接失败 ({error_type})。错误: {error}") +def _is_socks_proxy(proxy: str) -> bool: + """Check if the proxy URL is a SOCKS proxy. + + Args: + proxy: The proxy URL string + + Returns: + True if the proxy is a SOCKS proxy (socks4://, socks5://, socks5h://) + """ + return proxy.lower().startswith(("socks4://", "socks5://", "socks5h://")) + + +def _sanitize_proxy_url(proxy: str) -> str: + """Sanitize proxy URL by masking credentials for safe logging. + + Args: + proxy: The proxy URL string + + Returns: + Sanitized proxy URL with credentials masked (e.g., "http://****@host:port") + """ + try: + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(proxy) + # Any userinfo in netloc should be masked to avoid leaking tokens/passwords. + if "@" in parsed.netloc and parsed.hostname: + host = parsed.hostname + if ":" in host and not host.startswith("["): + host = f"[{host}]" + netloc = f"****@{host}" + if parsed.port: + netloc += f":{parsed.port}" + return urlunparse( + ( + parsed.scheme, + netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + except Exception: + return "****" + return proxy + + def create_proxy_client( provider_label: str, proxy: str | None = None, @@ -95,8 +147,27 @@ def create_proxy_client( Returns: An httpx.AsyncClient configured with the proxy, or None if no proxy + + Raises: + ImportError: If SOCKS proxy is used but socksio is not installed """ - if proxy: - logger.info(f"[{provider_label}] 使用代理: {proxy}") - return httpx.AsyncClient(proxy=proxy) - return None + if not proxy: + return None + + sanitized_proxy = _sanitize_proxy_url(proxy) + logger.info(f"[{provider_label}] 使用代理: {sanitized_proxy}") + + # Check for SOCKS proxy and provide helpful error if socksio is not installed + if _is_socks_proxy(proxy): + try: + import socksio # noqa: F401 + except ImportError: + raise ImportError( + f"使用 SOCKS 代理需要安装 socksio 包。请运行以下命令安装:\n" + f" pip install 'httpx[socks]'\n" + f"或者:\n" + f" pip install socksio\n" + f"代理地址: {sanitized_proxy}" + ) from None + + return httpx.AsyncClient(proxy=proxy) diff --git a/pyproject.toml b/pyproject.toml index 463da49556..9f27454549 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dependencies = [ "shipyard-python-sdk>=0.2.4", "shipyard-neo-sdk>=0.2.0", "python-socks>=2.8.0", + "socksio>=1.0.0", "packaging>=24.2", ] diff --git a/tests/unit/test_network_utils.py b/tests/unit/test_network_utils.py new file mode 100644 index 0000000000..6504114b25 --- /dev/null +++ b/tests/unit/test_network_utils.py @@ -0,0 +1,118 @@ +"""Tests for network utility helpers.""" + +import builtins + +from astrbot.core.utils import network_utils + + +def test_sanitize_proxy_url_masks_password_credentials(): + proxy = "http://user:secret@127.0.0.1:1080" + assert network_utils._sanitize_proxy_url(proxy) == "http://****@127.0.0.1:1080" + + +def test_sanitize_proxy_url_masks_username_only_credentials(): + proxy = "http://token@127.0.0.1:1080" + assert network_utils._sanitize_proxy_url(proxy) == "http://****@127.0.0.1:1080" + + +def test_sanitize_proxy_url_masks_empty_password_credentials(): + proxy = "http://user:@127.0.0.1:1080" + assert network_utils._sanitize_proxy_url(proxy) == "http://****@127.0.0.1:1080" + + +def test_sanitize_proxy_url_returns_original_when_no_credentials(): + proxy = "http://127.0.0.1:1080" + assert network_utils._sanitize_proxy_url(proxy) == proxy + + +def test_sanitize_proxy_url_returns_original_for_non_url_text(): + proxy = "not a url" + assert network_utils._sanitize_proxy_url(proxy) == proxy + + +def test_sanitize_proxy_url_returns_original_for_empty_string(): + assert network_utils._sanitize_proxy_url("") == "" + + +def test_sanitize_proxy_url_masks_credentials_for_ipv6_host(): + proxy = "http://user:secret@[::1]:1080" + assert network_utils._sanitize_proxy_url(proxy) == "http://****@[::1]:1080" + + +def test_sanitize_proxy_url_falls_back_to_placeholder_on_parse_error(monkeypatch): + proxy = "http://user:secret@127.0.0.1:1080" + original_import = builtins.__import__ + + def guarded_import(name, globals_=None, locals_=None, fromlist=(), level=0): + if name == "urllib.parse": + raise ImportError("boom") + return original_import(name, globals_, locals_, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", guarded_import) + + assert network_utils._sanitize_proxy_url(proxy) == "****" + + +def test_is_socks_proxy_detects_supported_schemes(): + assert network_utils._is_socks_proxy("socks5://127.0.0.1:1080") + assert network_utils._is_socks_proxy("socks4://127.0.0.1:1080") + assert network_utils._is_socks_proxy("socks5h://127.0.0.1:1080") + assert not network_utils._is_socks_proxy("http://127.0.0.1:1080") + + +def test_log_connection_failure_redacts_proxy_in_error_text(monkeypatch): + proxy = "http://token@127.0.0.1:1080" + captured = {} + + def fake_error(message: str): + captured["message"] = message + + monkeypatch.setattr(network_utils.logger, "error", fake_error) + + network_utils.log_connection_failure( + provider_label="OpenAI", + error=RuntimeError(f"proxy connect failed: {proxy}"), + proxy=proxy, + ) + + assert "http://token@127.0.0.1:1080" not in captured["message"] + assert "http://****@127.0.0.1:1080" in captured["message"] + + +def test_log_connection_failure_without_proxy_does_not_log_proxy_label(monkeypatch): + captured = {} + + def fake_error(message: str): + captured["message"] = message + + monkeypatch.setattr(network_utils.logger, "error", fake_error) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + + network_utils.log_connection_failure( + provider_label="OpenAI", + error=RuntimeError("connection failed"), + proxy=None, + ) + + assert "代理地址" not in captured["message"] + assert "connection failed" in captured["message"] + + +def test_log_connection_failure_keeps_error_text_when_no_proxy_text(monkeypatch): + proxy = "http://token@127.0.0.1:1080" + captured = {} + + def fake_error(message: str): + captured["message"] = message + + monkeypatch.setattr(network_utils.logger, "error", fake_error) + + network_utils.log_connection_failure( + provider_label="OpenAI", + error=RuntimeError("connect timeout"), + proxy=proxy, + ) + + assert "http://****@127.0.0.1:1080" in captured["message"] + assert "connect timeout" in captured["message"]