diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index b19f3460dd..7ee5b67ee4 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -459,34 +459,77 @@ def __init__(self, provider_config, provider_settings) -> None: for key in self.custom_headers: self.custom_headers[key] = str(self.custom_headers[key]) - if "api_version" in provider_config: + self.client = self._create_openai_client() + + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() + + model = provider_config.get("model", "unknown") + self.set_model(model) + + self.reasoning_key = "reasoning_content" + + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: + """创建带代理的 HTTP 客户端""" + proxy = provider_config.get("proxy", "") + + return create_proxy_client("OpenAI", proxy) + + def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI: + """创建 OpenAI/Azure 客户端实例,将初始化逻辑解耦以便复用。""" + if "api_version" in self.provider_config: # Using Azure OpenAI API - self.client = AsyncAzureOpenAI( + return AsyncAzureOpenAI( api_key=self.chosen_api_key, - api_version=provider_config.get("api_version", None), + api_version=self.provider_config.get("api_version", None), default_headers=self.custom_headers, - base_url=provider_config.get("api_base", ""), + base_url=self.provider_config.get("api_base", ""), timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) else: # Using OpenAI Official API - self.client = AsyncOpenAI( + return AsyncOpenAI( api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), + base_url=self.provider_config.get("api_base", None), default_headers=self.custom_headers, timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) - self.default_params = inspect.signature( - self.client.chat.completions.create, - ).parameters.keys() + def _is_underlying_client_closed(self) -> bool: + """集中处理对 openai SDK 私有属性的访问,便于未来替换为公开 API。 - model = provider_config.get("model", "unknown") - self.set_model(model) + 注意:此处直接访问了 openai 库的私有属性 `_client`, + 依赖其内部实现(httpx.AsyncClient 实例暴露的 is_closed 属性)。 + 若 openai 库未来版本调整了内部结构,此处可能失效。 + 目前 openai SDK 尚未提供检查底层连接是否已关闭的公开 API。 + 若未来 SDK 提供了类似 self.client.is_closed() 的公开方法, + 应及时将此处替换为对应的公开接口。 - self.reasoning_key = "reasoning_content" + 当检测逻辑因 SDK 内部结构变更而抛出 AttributeError 时,会: + 1. 记录 warning 日志,提示可能的 SDK 变更; + 2. 保守地视为"已关闭",触发后续的 client 重建逻辑。 + """ + try: + return bool(self.client and self.client._client.is_closed) + except AttributeError: + logger.warning( + "无法检测 OpenAI client 是否已关闭," + "可能是 SDK 内部结构变更导致;" + "将保守视为已关闭并触发 client 重建。" + ) + return True + + def _ensure_client(self) -> None: + """确保 client 可用。如果 client 为 None 或底层连接已关闭,则重新创建。""" + if self.client is None or self._is_underlying_client_closed(): + logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...") + self.client = self._create_openai_client() + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() def _ollama_disable_thinking_enabled(self) -> bool: value = self.provider_config.get("ollama_disable_thinking", False) @@ -509,6 +552,7 @@ def _apply_provider_specific_extra_body_overrides( extra_body["reasoning_effort"] = "none" async def get_models(self): + self._ensure_client() try: models_str = [] models = await self.client.models.list() @@ -520,6 +564,7 @@ async def get_models(self): raise Exception(f"获取模型列表失败:{e}") async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -592,6 +637,7 @@ async def _query_stream( tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -1145,6 +1191,8 @@ async def text_chat( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() + self.chosen_api_key = chosen_key self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break @@ -1216,6 +1264,8 @@ async def text_chat_stream( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() + self.chosen_api_key = chosen_key self.client.api_key = chosen_key async for response in self._query_stream(payloads, func_tool): yield response @@ -1270,12 +1320,15 @@ async def _remove_image_from_context(self, contexts: list): return new_contexts def get_current_key(self) -> str: + self._ensure_client() return self.client.api_key def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key) -> None: + self.chosen_api_key = key + self._ensure_client() self.client.api_key = key async def assemble_context( @@ -1355,5 +1408,16 @@ async def encode_image_bs64(self, image_url: str) -> str: return image_data async def terminate(self): + """关闭 client 并将引用置为 None。 + + 通过 try/finally 确保即使 close() 抛出异常, + self.client 也会被清空,避免配置重载(reload)期间 + 复用已关闭的 client 导致 APIConnectionError。 + """ if self.client: - await self.client.close() + try: + await self.client.close() + except Exception as e: + logger.warning(f"关闭 OpenAI client 时出错: {e}") + finally: + self.client = None