Skip to content
Merged
Changes from 5 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e570469
fix: /model command now auto-switches provider when model exists else…
pandyzhou Feb 28, 2026
ec7995e
fix: address Sourcery review - log get_models() failures in cross-pro…
pandyzhou Feb 28, 2026
f9cd842
fix: integer branch exception handling and API key masking in model c…
pandyzhou Feb 28, 2026
83fb0e8
fix: harden cross-provider model resolution
zouyonghe Mar 1, 2026
4d5c8ae
fix: improve model lookup resilience and cache hygiene
zouyonghe Mar 1, 2026
5371267
refactor: simplify model switch lookup flow
zouyonghe Mar 1, 2026
281c44f
refactor: streamline provider model cache updates
zouyonghe Mar 1, 2026
ba1b1ff
fix: align provider annotations and key error flow
zouyonghe Mar 1, 2026
d062abf
fix: narrow provider command exception handling
zouyonghe Mar 1, 2026
1d611d4
refactor: harden provider command error redaction and flow
zouyonghe Mar 1, 2026
5293aef
fix: improve provider model lookup and secret redaction
zouyonghe Mar 1, 2026
418a405
refactor: cache normalized model names in provider lookup
zouyonghe Mar 1, 2026
abe31a3
refactor: simplify provider model lookup helpers
zouyonghe Mar 1, 2026
34235c6
refactor: extract provider model lookup helpers
zouyonghe Mar 1, 2026
40b7fd3
fix: harden provider lookup cancellation and redaction
zouyonghe Mar 1, 2026
cf7da2f
refactor: streamline provider cache and lookup settings
zouyonghe Mar 1, 2026
5aeccbb
refactor: simplify provider command setting and update helpers
zouyonghe Mar 1, 2026
747bede
refactor: streamline provider model lookup config usage
zouyonghe Mar 1, 2026
9059c1b
refactor: flatten provider lookup settings and filter model lookup pr…
zouyonghe Mar 1, 2026
b76ba0c
refactor: simplify provider cache and callback flow
zouyonghe Mar 1, 2026
f139be7
refactor: simplify provider command model cache flow
zouyonghe Mar 1, 2026
086a11e
refactor: scope provider model cache by session
zouyonghe Mar 1, 2026
597690e
fix: preserve redaction context and restore provider hooks
zouyonghe Mar 1, 2026
9c5f31f
refactor: unify provider model lookup config flow
zouyonghe Mar 1, 2026
161763e
refactor: inline provider model cache access flow
zouyonghe Mar 1, 2026
7d1f642
fix: align provider lookup cache and callback semantics
zouyonghe Mar 1, 2026
a9f1558
refactor: centralize provider model fetch error handling
zouyonghe Mar 1, 2026
343d277
refactor: simplify provider model cache and lookup flow
zouyonghe Mar 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 158 additions & 13 deletions astrbot/builtin_stars/builtin_commands/commands/provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
import asyncio
import re
import time
from typing import TYPE_CHECKING

from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType

if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider

_API_KEY_PATTERN = re.compile(r"key=[^&'\" ]+")
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated


class ProviderCommands:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
_MODEL_LIST_CACHE_TTL_SECONDS = 30.0

def __init__(self, context: star.Context) -> None:
self.context = context
self._provider_models_cache: dict[str, tuple[float, list[str]]] = {}

def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
if provider_id is None:
self._provider_models_cache.clear()
return
self._provider_models_cache.pop(provider_id, None)

@staticmethod
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
def _mask_sensitive_text(value: str) -> str:
return _API_KEY_PATTERN.sub("key=***", value)

async def _get_provider_models(
self, provider: "Provider", *, use_cache: bool = True
) -> list[str]:
provider_id = provider.meta().id
now = time.monotonic()
if use_cache:
cached = self._provider_models_cache.get(provider_id)
if cached and now - cached[0] <= self._MODEL_LIST_CACHE_TTL_SECONDS:
return list(cached[1])

models = list(await provider.get_models())
self._provider_models_cache[provider_id] = (now, models)
return list(models)

def _log_reachability_failure(
self,
Expand Down Expand Up @@ -44,6 +78,52 @@ async def _test_provider_capability(self, provider):
)
return False, err_code, err_reason

async def _find_provider_for_model(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
self, model_name: str, exclude_provider_id: str | None = None
) -> tuple["Provider" | None, str | None]:
"""在所有 LLM 提供商中查找包含指定模型的提供商。返回 (provider, provider_id) 或 (None, None)。"""
all_providers = [
p
for p in self.context.get_all_providers()
if not exclude_provider_id or p.meta().id != exclude_provider_id
]
if not all_providers:
return None, None
results = await asyncio.gather(
*[self._get_provider_models(p) for p in all_providers],
return_exceptions=True,
)
failed_provider_errors: list[tuple[str, str]] = []
for provider, result in zip(all_providers, results):
if isinstance(result, BaseException):
masked_error = self._mask_sensitive_text(str(result))
failed_provider_errors.append((provider.meta().id, masked_error))
continue
provider_id = provider.meta().id
if model_name in result:
return provider, provider_id
if failed_provider_errors and len(failed_provider_errors) == len(all_providers):
failed_ids = ",".join(
provider_id for provider_id, _ in failed_provider_errors
)
logger.error(
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
model_name,
len(all_providers),
failed_ids,
)
elif failed_provider_errors:
logger.debug(
"跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s",
model_name,
len(failed_provider_errors),
",".join(
f"{provider_id}({error})"
for provider_id, error in failed_provider_errors
),
)
return None, None
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
pandyzhou marked this conversation as resolved.

async def provider(
self,
event: AstrMessageEvent,
Expand Down Expand Up @@ -236,15 +316,13 @@ async def model_ls(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
# 定义正则表达式匹配 API 密钥
api_key_pattern = re.compile(r"key=[^&'\" ]+")

if idx_or_name is None:
models = []
try:
models = await prov.get_models()
models = await self._get_provider_models(prov)
except BaseException as e:
err_msg = api_key_pattern.sub("key=***", str(e))
err_msg = self._mask_sensitive_text(str(e))
message.set_result(
MessageEventResult()
.message("获取模型列表失败: " + err_msg)
Expand All @@ -258,18 +336,19 @@ async def model_ls(
curr_model = prov.get_model() or "无"
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。"
)

ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
elif isinstance(idx_or_name, int):
models = []
try:
models = await prov.get_models()
models = await self._get_provider_models(prov)
except BaseException as e:
err_msg = self._mask_sensitive_text(str(e))
message.set_result(
MessageEventResult().message("获取模型列表失败: " + str(e)),
MessageEventResult().message("获取模型列表失败: " + err_msg),
)
return
if idx_or_name > len(models) or idx_or_name < 1:
Expand All @@ -278,20 +357,85 @@ async def model_ls(
try:
new_model = models[idx_or_name - 1]
prov.set_model(new_model)
self._invalidate_provider_models_cache(prov.meta().id)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
),
)
except BaseException as e:
err_msg = self._mask_sensitive_text(str(e))
message.set_result(
MessageEventResult().message("切换模型未知错误: " + str(e)),
MessageEventResult().message("切换模型未知错误: " + err_msg),
)
return
else:
# 字符串:模型名,需智能解析是否跨提供商
model_name = idx_or_name.strip()
umo = message.unified_msg_origin
curr_provider_id = prov.meta().id
if not model_name:
message.set_result(MessageEventResult().message("模型名不能为空。"))
return

# 1. 检查当前提供商
models = []
try:
models = await self._get_provider_models(prov)
except BaseException as e:
err_msg = self._mask_sensitive_text(str(e))
logger.warning(
"获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
curr_provider_id,
err_msg,
)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
"获取当前提供商模型列表失败: " + err_msg
),
)
else:
prov.set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型到 {prov.get_model()}。"),
return
if model_name in models:
prov.set_model(model_name)
Comment thread
pandyzhou marked this conversation as resolved.
Outdated
self._invalidate_provider_models_cache(curr_provider_id)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]",
),
)
return

# 2. 在其他提供商中查找
target_prov, target_id = await self._find_provider_for_model(
model_name, exclude_provider_id=curr_provider_id
)
if target_prov and target_id:
try:
await self.context.provider_manager.set_provider(
provider_id=target_id,
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
target_prov.set_model(model_name)
self._invalidate_provider_models_cache(target_id)
message.set_result(
MessageEventResult().message(
f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
),
)
except BaseException as e:
err_msg = self._mask_sensitive_text(str(e))
message.set_result(
MessageEventResult().message(
"跨提供商切换并设置模型失败: " + err_msg
),
)
else:
message.set_result(
MessageEventResult().message(
f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。",
),
)

async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
Expand Down Expand Up @@ -322,6 +466,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
try:
new_key = keys_data[index - 1]
prov.set_key(new_key)
self._invalidate_provider_models_cache(prov.meta().id)
except BaseException as e:
message.set_result(
MessageEventResult().message(f"切换 Key 未知错误: {e!s}"),
Expand Down