Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import httpx
from openai import AsyncOpenAI

from astrbot import logger
from astrbot.core.utils.network_utils import create_proxy_client

from ..entities import ProviderType
from ..provider import EmbeddingProvider
Expand All @@ -19,10 +18,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
self.provider_config = provider_config
self.provider_settings = provider_settings
proxy = provider_config.get("proxy", "")
http_client = None
if proxy:
logger.info(f"[OpenAI Embedding] 使用代理: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
http_client = create_proxy_client("OpenAI Embedding", proxy)
api_base = provider_config.get("embedding_api_base", "").strip()
if not api_base:
api_base = "https://api.openai.com/v1"
Expand Down
8 changes: 2 additions & 6 deletions astrbot/core/provider/sources/openai_tts_api_source.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import uuid

import httpx
from openai import NOT_GIVEN, AsyncOpenAI

from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.network_utils import create_proxy_client

from ..entities import ProviderType
from ..provider import TTSProvider
Expand All @@ -32,10 +31,7 @@ def __init__(
timeout = int(timeout)

proxy = provider_config.get("proxy", "")
http_client = None
if proxy:
logger.info(f"[OpenAI TTS] 使用代理: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
http_client = create_proxy_client("OpenAI TTS", proxy)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base"),
Expand Down
81 changes: 76 additions & 5 deletions astrbot/core/utils/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,66 @@ def log_connection_failure(
)

if effective_proxy:
sanitized_proxy = _sanitize_proxy_url(effective_proxy)
error_text = str(error)
if effective_proxy:
error_text = error_text.replace(effective_proxy, sanitized_proxy)
logger.error(
f"[{provider_label}] 网络/代理连接失败 ({error_type})。"
f"代理地址: {effective_proxy},错误: {error}"
f"代理地址: {sanitized_proxy},错误: {error_text}"
)
else:
logger.error(f"[{provider_label}] 网络连接失败 ({error_type})。错误: {error}")


def _is_socks_proxy(proxy: str) -> bool:
"""Check if the proxy URL is a SOCKS proxy.

Args:
proxy: The proxy URL string

Returns:
True if the proxy is a SOCKS proxy (socks4://, socks5://, socks5h://)
"""
return proxy.lower().startswith(("socks4://", "socks5://", "socks5h://"))


def _sanitize_proxy_url(proxy: str) -> str:
"""Sanitize proxy URL by masking credentials for safe logging.

Args:
proxy: The proxy URL string

Returns:
Sanitized proxy URL with credentials masked (e.g., "http://****@host:port")
"""
try:
from urllib.parse import urlparse, urlunparse

parsed = urlparse(proxy)
# Any userinfo in netloc should be masked to avoid leaking tokens/passwords.
if "@" in parsed.netloc and parsed.hostname:
host = parsed.hostname
if ":" in host and not host.startswith("["):
host = f"[{host}]"
netloc = f"****@{host}"
if parsed.port:
netloc += f":{parsed.port}"
return urlunparse(
(
parsed.scheme,
netloc,
parsed.path,
parsed.params,
parsed.query,
parsed.fragment,
)
)
except Exception:
return "****"
return proxy


def create_proxy_client(
provider_label: str,
proxy: str | None = None,
Expand All @@ -95,8 +147,27 @@ def create_proxy_client(

Returns:
An httpx.AsyncClient configured with the proxy, or None if no proxy

Raises:
ImportError: If SOCKS proxy is used but socksio is not installed
"""
if proxy:
logger.info(f"[{provider_label}] 使用代理: {proxy}")
return httpx.AsyncClient(proxy=proxy)
return None
if not proxy:
return None

sanitized_proxy = _sanitize_proxy_url(proxy)
logger.info(f"[{provider_label}] 使用代理: {sanitized_proxy}")

# Check for SOCKS proxy and provide helpful error if socksio is not installed
if _is_socks_proxy(proxy):
try:
import socksio # noqa: F401
except ImportError:
raise ImportError(
f"使用 SOCKS 代理需要安装 socksio 包。请运行以下命令安装:\n"
f" pip install 'httpx[socks]'\n"
f"或者:\n"
f" pip install socksio\n"
f"代理地址: {sanitized_proxy}"
) from None

return httpx.AsyncClient(proxy=proxy)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ dependencies = [
"shipyard-python-sdk>=0.2.4",
"shipyard-neo-sdk>=0.2.0",
"python-socks>=2.8.0",
"socksio>=1.0.0",
"packaging>=24.2",
]

Expand Down
118 changes: 118 additions & 0 deletions tests/unit/test_network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tests for network utility helpers."""

import builtins

from astrbot.core.utils import network_utils


def test_sanitize_proxy_url_masks_password_credentials():
    # user:password pairs must never survive into log output.
    sanitized = network_utils._sanitize_proxy_url("http://user:secret@127.0.0.1:1080")
    assert sanitized == "http://****@127.0.0.1:1080"


def test_sanitize_proxy_url_masks_username_only_credentials():
    # A bare token in the userinfo slot is still a secret.
    sanitized = network_utils._sanitize_proxy_url("http://token@127.0.0.1:1080")
    assert sanitized == "http://****@127.0.0.1:1080"


def test_sanitize_proxy_url_masks_empty_password_credentials():
    # "user:" with an empty password is masked just like a full pair.
    sanitized = network_utils._sanitize_proxy_url("http://user:@127.0.0.1:1080")
    assert sanitized == "http://****@127.0.0.1:1080"


def test_sanitize_proxy_url_returns_original_when_no_credentials():
    # Credential-free URLs pass through untouched.
    url = "http://127.0.0.1:1080"
    assert network_utils._sanitize_proxy_url(url) == url


def test_sanitize_proxy_url_returns_original_for_non_url_text():
    # Arbitrary non-URL text is returned unchanged rather than mangled.
    text = "not a url"
    assert network_utils._sanitize_proxy_url(text) == text


def test_sanitize_proxy_url_returns_original_for_empty_string():
    # The empty string is a valid "no proxy" value and stays empty.
    assert network_utils._sanitize_proxy_url("") == ""


def test_sanitize_proxy_url_masks_credentials_for_ipv6_host():
    # IPv6 literals keep their brackets after the userinfo is masked.
    sanitized = network_utils._sanitize_proxy_url("http://user:secret@[::1]:1080")
    assert sanitized == "http://****@[::1]:1080"


def test_sanitize_proxy_url_falls_back_to_placeholder_on_parse_error(monkeypatch):
    # Simulate urllib.parse being unavailable: the sanitizer must degrade to
    # masking everything instead of raising.
    real_import = builtins.__import__

    def failing_import(name, *args, **kwargs):
        if name == "urllib.parse":
            raise ImportError("boom")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", failing_import)

    result = network_utils._sanitize_proxy_url("http://user:secret@127.0.0.1:1080")
    assert result == "****"


def test_is_socks_proxy_detects_supported_schemes():
    # All three SOCKS scheme variants are recognised; plain HTTP is not.
    for scheme in ("socks4", "socks5", "socks5h"):
        assert network_utils._is_socks_proxy(f"{scheme}://127.0.0.1:1080")
    assert not network_utils._is_socks_proxy("http://127.0.0.1:1080")


def test_log_connection_failure_redacts_proxy_in_error_text(monkeypatch):
    # When the underlying error message echoes the proxy URL, the logged line
    # must contain only the masked form.
    proxy = "http://token@127.0.0.1:1080"
    messages: list[str] = []
    monkeypatch.setattr(network_utils.logger, "error", messages.append)

    network_utils.log_connection_failure(
        provider_label="OpenAI",
        error=RuntimeError(f"proxy connect failed: {proxy}"),
        proxy=proxy,
    )

    logged = messages[0]
    assert "http://token@127.0.0.1:1080" not in logged
    assert "http://****@127.0.0.1:1080" in logged


def test_log_connection_failure_without_proxy_does_not_log_proxy_label(monkeypatch):
    messages: list[str] = []
    monkeypatch.setattr(network_utils.logger, "error", messages.append)
    # Clear ambient proxy env vars so no proxy is picked up implicitly.
    for var in ("http_proxy", "https_proxy"):
        monkeypatch.delenv(var, raising=False)

    network_utils.log_connection_failure(
        provider_label="OpenAI",
        error=RuntimeError("connection failed"),
        proxy=None,
    )

    logged = messages[0]
    assert "代理地址" not in logged
    assert "connection failed" in logged


def test_log_connection_failure_keeps_error_text_when_no_proxy_text(monkeypatch):
    # An error message that never mentions the proxy is logged verbatim,
    # alongside the masked proxy address.
    messages: list[str] = []
    monkeypatch.setattr(network_utils.logger, "error", messages.append)

    network_utils.log_connection_failure(
        provider_label="OpenAI",
        error=RuntimeError("connect timeout"),
        proxy="http://token@127.0.0.1:1080",
    )

    logged = messages[0]
    assert "http://****@127.0.0.1:1080" in logged
    assert "connect timeout" in logged