diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 7644577594..d2fce17ded 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -18,6 +18,7 @@ from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.network_utils import ( + create_proxy_client, is_connection_error, log_connection_failure, ) @@ -106,15 +107,13 @@ def _init_api_key(self, provider_config: dict) -> None: http_client=self._create_http_client(provider_config), ) - def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: - """创建带代理的 HTTP 客户端""" - proxy = provider_config.get("proxy", "") - if proxy: - logger.info(f"[Anthropic] 使用代理: {proxy}") - return httpx.AsyncClient(proxy=proxy, headers=self.custom_headers) - if self.custom_headers: - return httpx.AsyncClient(headers=self.custom_headers) - return None + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: + """创建带代理的 HTTP 客户端,使用系统 SSL 证书""" + return create_proxy_client( + "Anthropic", + provider_config.get("proxy", ""), + headers=self.custom_headers, + ) def _apply_thinking_config(self, payloads: dict) -> None: thinking_type = self.thinking_config.get("type", "") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index b24bc0885b..67971a2a93 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -438,7 +438,7 @@ async def _fallback_to_text_only_and_retry( image_fallback_used, ) - def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: """创建带代理的 HTTP 客户端""" proxy = provider_config.get("proxy", "") return create_proxy_client("OpenAI", proxy) diff --git a/astrbot/core/utils/network_utils.py b/astrbot/core/utils/network_utils.py index 727f3762ae..047529396e 100644 --- a/astrbot/core/utils/network_utils.py +++ b/astrbot/core/utils/network_utils.py @@ -1,9 +1,13 @@ """Network error handling utilities for providers.""" +import ssl + import httpx from astrbot import logger +_SYSTEM_SSL_CTX = ssl.create_default_context() + def is_connection_error(exc: BaseException) -> bool: """Check if an exception is a connection/network related error. @@ -83,20 +87,30 @@ def log_connection_failure( def create_proxy_client( provider_label: str, proxy: str | None = None, -) -> httpx.AsyncClient | None: + headers: dict[str, str] | None = None, + verify: ssl.SSLContext | str | bool | None = None, +) -> httpx.AsyncClient: """Create an httpx AsyncClient with proxy configuration if provided. + Uses the system SSL certificate store instead of certifi, which avoids + SSL verification failures for endpoints whose CA chain is not in certifi + but is trusted by the operating system. + Note: The caller is responsible for closing the client when done. Consider using the client as a context manager or calling aclose() explicitly. Args: provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini") proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty + headers: Optional custom headers to include in every request + verify: Optional override for TLS verification. Defaults to the shared + system SSL context when not provided. Returns: - An httpx.AsyncClient configured with the proxy, or None if no proxy + An httpx.AsyncClient created with the shared system SSL context; the proxy is applied only if one is provided. """ + resolved_verify = _SYSTEM_SSL_CTX if verify is None else verify if proxy: logger.info(f"[{provider_label}] 使用代理: {proxy}") - return httpx.AsyncClient(proxy=proxy) - return None + return httpx.AsyncClient(proxy=proxy, verify=resolved_verify, headers=headers) + return httpx.AsyncClient(verify=resolved_verify, headers=headers) diff --git a/tests/unit/test_network_utils.py b/tests/unit/test_network_utils.py new file mode 100644 index 0000000000..ea3505e387 --- /dev/null +++ b/tests/unit/test_network_utils.py @@ -0,0 +1,51 @@ +import ssl + +import pytest + +from astrbot.core.utils import network_utils + + +def test_create_proxy_client_reuses_shared_ssl_context( + monkeypatch: pytest.MonkeyPatch, +): + captured_calls: list[dict] = [] + headers = {"X-Test-Header": "value"} + + class _FakeAsyncClient: + def __init__(self, **kwargs): + captured_calls.append(kwargs) + + monkeypatch.setattr(network_utils.httpx, "AsyncClient", _FakeAsyncClient) + + network_utils.create_proxy_client("OpenAI") + network_utils.create_proxy_client("OpenAI", proxy="http://127.0.0.1:7890") + network_utils.create_proxy_client("OpenAI", headers=headers) + network_utils.create_proxy_client("OpenAI", proxy="") + + assert len(captured_calls) == 4 + assert "proxy" not in captured_calls[0] + assert captured_calls[1]["proxy"] == "http://127.0.0.1:7890" + assert captured_calls[2]["headers"] is headers + assert "proxy" not in captured_calls[3] + assert isinstance(captured_calls[0]["verify"], ssl.SSLContext) + assert captured_calls[0]["verify"] is captured_calls[1]["verify"] + assert captured_calls[1]["verify"] is captured_calls[2]["verify"] + assert captured_calls[2]["verify"] is captured_calls[3]["verify"] + + +def test_create_proxy_client_allows_verify_override( + monkeypatch: pytest.MonkeyPatch, +): + captured_calls: list[dict] = [] + custom_verify = ssl.create_default_context() + + class _FakeAsyncClient: + def __init__(self, **kwargs): + captured_calls.append(kwargs) + + monkeypatch.setattr(network_utils.httpx, "AsyncClient", _FakeAsyncClient) + + network_utils.create_proxy_client("OpenAI", verify=custom_verify) + + assert len(captured_calls) == 1 + assert captured_calls[0]["verify"] is custom_verify