Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 79 additions & 15 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,34 +459,77 @@ def __init__(self, provider_config, provider_settings) -> None:
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])

if "api_version" in provider_config:
self.client = self._create_openai_client()

self.default_params = inspect.signature(
self.client.chat.completions.create,
).parameters.keys()

model = provider_config.get("model", "unknown")
self.set_model(model)

self.reasoning_key = "reasoning_content"

def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
    """Build an httpx client that honors the configured proxy, if any.

    Returns None when no proxy-aware client is needed (delegated to
    ``create_proxy_client``).
    """
    return create_proxy_client("OpenAI", provider_config.get("proxy", ""))

def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI:
    """创建 OpenAI/Azure 客户端实例,将初始化逻辑解耦以便复用。

    Reads connection settings from ``self.provider_config`` so the same
    logic can be reused to rebuild the client after it has been closed
    (e.g. during a config reload).

    Returns:
        An ``AsyncAzureOpenAI`` instance when the config carries an
        ``api_version`` key (Azure convention), otherwise a plain
        ``AsyncOpenAI`` instance.
    """
    if "api_version" in self.provider_config:
        # Presence of "api_version" signals the Azure OpenAI API.
        return AsyncAzureOpenAI(
            api_key=self.chosen_api_key,
            api_version=self.provider_config.get("api_version", None),
            default_headers=self.custom_headers,
            base_url=self.provider_config.get("api_base", ""),
            timeout=self.timeout,
            http_client=self._create_http_client(self.provider_config),
        )
    # Using OpenAI Official (or compatible) API.
    return AsyncOpenAI(
        api_key=self.chosen_api_key,
        base_url=self.provider_config.get("api_base", None),
        default_headers=self.custom_headers,
        timeout=self.timeout,
        http_client=self._create_http_client(self.provider_config),
    )

self.default_params = inspect.signature(
self.client.chat.completions.create,
).parameters.keys()
def _is_underlying_client_closed(self) -> bool:
    """集中处理对 openai SDK 私有属性的访问,便于未来替换为公开 API。

    NOTE(review): this reads the openai SDK's private ``_client``
    attribute, relying on its internal implementation (an
    ``httpx.AsyncClient`` exposing ``is_closed``). The SDK currently has
    no public API for checking whether the underlying connection is
    closed; if a public method such as ``self.client.is_closed()``
    appears in a future release, switch to it here.

    When the check raises ``AttributeError`` (SDK internals changed):
    1. a warning is logged to flag the likely SDK change;
    2. the client is conservatively treated as closed, which triggers
       the rebuild path in ``_ensure_client``.

    Returns:
        True when the client is missing or its underlying transport is
        closed (or the state cannot be determined), False otherwise.
    """
    try:
        return bool(self.client and self.client._client.is_closed)
    except AttributeError:
        logger.warning(
            "无法检测 OpenAI client 是否已关闭,"
            "可能是 SDK 内部结构变更导致;"
            "将保守视为已关闭并触发 client 重建。"
        )
        return True

def _ensure_client(self) -> None:
    """Guarantee a usable client: rebuild it when missing or closed."""
    if self.client is not None and not self._is_underlying_client_closed():
        return  # client is present and its connection is still open
    logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...")
    self.client = self._create_openai_client()
    # Refresh the cached parameter names of chat.completions.create,
    # since they are derived from the freshly built client.
    self.default_params = inspect.signature(
        self.client.chat.completions.create,
    ).parameters.keys()

def _ollama_disable_thinking_enabled(self) -> bool:
value = self.provider_config.get("ollama_disable_thinking", False)
Expand All @@ -509,6 +552,7 @@ def _apply_provider_specific_extra_body_overrides(
extra_body["reasoning_effort"] = "none"

async def get_models(self):
self._ensure_client()
try:
models_str = []
models = await self.client.models.list()
Expand All @@ -520,6 +564,7 @@ async def get_models(self):
raise Exception(f"获取模型列表失败:{e}")

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
self._ensure_client()
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
Expand Down Expand Up @@ -592,6 +637,7 @@ async def _query_stream(
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
self._ensure_client()
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
Expand Down Expand Up @@ -1145,6 +1191,8 @@ async def text_chat(
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self._ensure_client()
self.chosen_api_key = chosen_key
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
Expand Down Expand Up @@ -1216,6 +1264,8 @@ async def text_chat_stream(
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self._ensure_client()
self.chosen_api_key = chosen_key
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
Expand Down Expand Up @@ -1270,12 +1320,15 @@ async def _remove_image_from_context(self, contexts: list):
return new_contexts

def get_current_key(self) -> str:
    """Return the API key currently set on the underlying client."""
    self._ensure_client()
    return self.client.api_key

def get_keys(self) -> list[str]:
    """Return the full list of configured API keys."""
    return self.api_keys

def set_key(self, key) -> None:
    """Record *key* as the chosen API key and apply it to the client.

    The client is (re)created first if it is missing or closed, so the
    key is never written to a stale client instance.
    """
    self.chosen_api_key = key
    self._ensure_client()
    self.client.api_key = key

async def assemble_context(
Expand Down Expand Up @@ -1355,5 +1408,16 @@ async def encode_image_bs64(self, image_url: str) -> str:
return image_data

async def terminate(self):
    """关闭 client 并将引用置为 None。

    通过 try/finally 确保即使 close() 抛出异常,
    self.client 也会被清空,避免配置重载(reload)期间
    复用已关闭的 client 导致 APIConnectionError。
    """
    if self.client:
        try:
            await self.client.close()
        except Exception as e:
            # Best-effort close: log and continue so shutdown never fails here.
            logger.warning(f"关闭 OpenAI client 时出错: {e}")
        finally:
            # Always drop the reference so _ensure_client rebuilds next time.
            self.client = None