Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.network_utils import (
create_proxy_client,
is_connection_error,
log_connection_failure,
)
Expand Down Expand Up @@ -106,15 +107,13 @@ def _init_api_key(self, provider_config: dict) -> None:
http_client=self._create_http_client(provider_config),
)

def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
if proxy:
logger.info(f"[Anthropic] 使用代理: {proxy}")
return httpx.AsyncClient(proxy=proxy, headers=self.custom_headers)
if self.custom_headers:
return httpx.AsyncClient(headers=self.custom_headers)
return None
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient:
"""创建带代理的 HTTP 客户端,使用系统 SSL 证书"""
return create_proxy_client(
"Anthropic",
provider_config.get("proxy", ""),
headers=self.custom_headers,
)

def _apply_thinking_config(self, payloads: dict) -> None:
thinking_type = self.thinking_config.get("type", "")
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ async def _fallback_to_text_only_and_retry(
image_fallback_used,
)

def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
return create_proxy_client("OpenAI", proxy)
Expand Down
22 changes: 18 additions & 4 deletions astrbot/core/utils/network_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Network error handling utilities for providers."""

import ssl

import httpx

from astrbot import logger

_SYSTEM_SSL_CTX = ssl.create_default_context()


def is_connection_error(exc: BaseException) -> bool:
"""Check if an exception is a connection/network related error.
Expand Down Expand Up @@ -83,20 +87,30 @@ def log_connection_failure(
def create_proxy_client(
provider_label: str,
proxy: str | None = None,
) -> httpx.AsyncClient | None:
headers: dict[str, str] | None = None,
verify: ssl.SSLContext | str | bool | None = None,
) -> httpx.AsyncClient:
"""Create an httpx AsyncClient with proxy configuration if provided.

Uses the system SSL certificate store instead of certifi, which avoids
SSL verification failures for endpoints whose CA chain is not in certifi
but is trusted by the operating system.

Note: The caller is responsible for closing the client when done.
Consider using the client as a context manager or calling aclose() explicitly.

Args:
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty
headers: Optional custom headers to include in every request
verify: Optional override for TLS verification. Defaults to the shared
system SSL context when not provided.

Returns:
An httpx.AsyncClient configured with the proxy, or None if no proxy
An httpx.AsyncClient created with the shared system SSL context; the proxy is applied only if one is provided.
"""
resolved_verify = _SYSTEM_SSL_CTX if verify is None else verify
if proxy:
logger.info(f"[{provider_label}] 使用代理: {proxy}")
return httpx.AsyncClient(proxy=proxy)
return None
return httpx.AsyncClient(proxy=proxy, verify=resolved_verify, headers=headers)
return httpx.AsyncClient(verify=resolved_verify, headers=headers)
51 changes: 51 additions & 0 deletions tests/unit/test_network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ssl

import pytest

from astrbot.core.utils import network_utils


def test_create_proxy_client_reuses_shared_ssl_context(
monkeypatch: pytest.MonkeyPatch,
):
captured_calls: list[dict] = []
headers = {"X-Test-Header": "value"}

class _FakeAsyncClient:
def __init__(self, **kwargs):
captured_calls.append(kwargs)

monkeypatch.setattr(network_utils.httpx, "AsyncClient", _FakeAsyncClient)
Comment on lines +8 to +18
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add assertions to verify that proxy and headers arguments are correctly forwarded to httpx.AsyncClient.

This test covers SSL context reuse but doesn’t yet validate that headers and proxy are forwarded correctly. Please extend it (or add a new test) to assert that:

  • When proxy is omitted, AsyncClient is called without a proxy kwarg.
  • When proxy is provided, AsyncClient receives the expected proxy value.
  • When headers is provided to create_proxy_client, it is forwarded as the headers kwarg.

You can keep using _FakeAsyncClient and check captured_calls[i]["proxy"] and captured_calls[i]["headers"].


network_utils.create_proxy_client("OpenAI")
network_utils.create_proxy_client("OpenAI", proxy="http://127.0.0.1:7890")
network_utils.create_proxy_client("OpenAI", headers=headers)
Comment on lines +20 to +22
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add an explicit test for an empty proxy string to match how provider_config.get("proxy", "") is used.

Since create_proxy_client is frequently called with provider_config.get("proxy", ""), it can receive an empty string, which the helper treats as "no proxy". Please add a case like create_proxy_client("OpenAI", proxy="") and assert it matches the no-proxy behavior (no proxy in kwargs, same verify handling) so tests cover this usage explicitly.

network_utils.create_proxy_client("OpenAI", proxy="")

assert len(captured_calls) == 4
assert "proxy" not in captured_calls[0]
assert captured_calls[1]["proxy"] == "http://127.0.0.1:7890"
assert captured_calls[2]["headers"] is headers
assert "proxy" not in captured_calls[3]
assert isinstance(captured_calls[0]["verify"], ssl.SSLContext)
assert captured_calls[0]["verify"] is captured_calls[1]["verify"]
assert captured_calls[1]["verify"] is captured_calls[2]["verify"]
assert captured_calls[2]["verify"] is captured_calls[3]["verify"]


def test_create_proxy_client_allows_verify_override(
monkeypatch: pytest.MonkeyPatch,
):
captured_calls: list[dict] = []
custom_verify = ssl.create_default_context()

class _FakeAsyncClient:
def __init__(self, **kwargs):
captured_calls.append(kwargs)

monkeypatch.setattr(network_utils.httpx, "AsyncClient", _FakeAsyncClient)

network_utils.create_proxy_client("OpenAI", verify=custom_verify)

assert len(captured_calls) == 1
assert captured_calls[0]["verify"] is custom_verify
Loading