Skip to content

Commit 281c44f

Browse files
committed
refactor: streamline provider model cache updates
1 parent 5371267 commit 281c44f

1 file changed

Lines changed: 63 additions & 27 deletions

File tree

  • astrbot/builtin_stars/builtin_commands/commands

astrbot/builtin_stars/builtin_commands/commands/provider.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,27 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N
3131
return
3232
self._provider_models_cache.pop(provider_id, None)
3333

34+
def _invalidate_cache_for(self, provider: "Provider") -> None:
35+
self._invalidate_provider_models_cache(provider.meta().id)
36+
37+
def _set_model_and_invalidate(self, provider: "Provider", model_name: str) -> None:
38+
provider.set_model(model_name)
39+
self._invalidate_cache_for(provider)
40+
41+
def _set_key_and_invalidate(self, provider: "Provider", key: str) -> None:
42+
provider.set_key(key)
43+
self._invalidate_cache_for(provider)
44+
3445
@staticmethod
3546
def _mask_sensitive_text(value: str) -> str:
3647
return _API_KEY_PATTERN.sub("key=***", value)
3748

3849
def _safe_err(self, e: BaseException) -> str:
3950
return self._mask_sensitive_text(str(e))
4051

52+
def _format_err(self, prefix: str, e: BaseException) -> str:
53+
return f"{prefix}{self._safe_err(e)}"
54+
4155
async def _get_provider_models(
4256
self, provider: "Provider", *, use_cache: bool = True
4357
) -> list[str]:
@@ -96,18 +110,38 @@ async def _find_provider_for_model(
96110
]
97111
if not all_providers:
98112
return None
99-
results = await asyncio.gather(
100-
*[self._get_provider_models(p) for p in all_providers],
101-
return_exceptions=True,
102-
)
113+
114+
async def _fetch_models(
115+
provider: "Provider",
116+
) -> tuple["Provider", list[str] | None, BaseException | None]:
117+
try:
118+
return provider, await self._get_provider_models(provider), None
119+
except BaseException as e:
120+
return provider, None, e
121+
122+
tasks = [
123+
asyncio.create_task(_fetch_models(provider)) for provider in all_providers
124+
]
103125
failed_provider_errors: list[tuple[str, str]] = []
104-
for provider, result in zip(all_providers, results):
105-
if isinstance(result, BaseException):
106-
masked_error = self._safe_err(result)
126+
matched_provider: Provider | None = None
127+
for task in asyncio.as_completed(tasks):
128+
provider, models, error = await task
129+
if error is not None:
130+
masked_error = self._safe_err(error)
107131
failed_provider_errors.append((provider.meta().id, masked_error))
108132
continue
109-
if model_name in result:
110-
return provider
133+
134+
if models is not None and model_name in models:
135+
matched_provider = provider
136+
break
137+
138+
if matched_provider is not None:
139+
for task in tasks:
140+
if not task.done():
141+
task.cancel()
142+
await asyncio.gather(*tasks, return_exceptions=True)
143+
return matched_provider
144+
111145
if failed_provider_errors and len(failed_provider_errors) == len(all_providers):
112146
failed_ids = ",".join(
113147
provider_id for provider_id, _ in failed_provider_errors
@@ -334,13 +368,14 @@ async def _switch_model_by_name(
334368
err_msg,
335369
)
336370
message.set_result(
337-
MessageEventResult().message("获取当前提供商模型列表失败: " + err_msg),
371+
MessageEventResult().message(
372+
self._format_err("获取当前提供商模型列表失败: ", e)
373+
)
338374
)
339375
return
340376

341377
if model_name in models:
342-
prov.set_model(model_name)
343-
self._invalidate_provider_models_cache(curr_provider_id)
378+
self._set_model_and_invalidate(prov, model_name)
344379
message.set_result(
345380
MessageEventResult().message(
346381
f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]",
@@ -375,17 +410,17 @@ async def _switch_model_by_name(
375410
provider_type=ProviderType.CHAT_COMPLETION,
376411
umo=umo,
377412
)
378-
target_prov.set_model(model_name)
379-
self._invalidate_provider_models_cache(target_id)
413+
self._set_model_and_invalidate(target_prov, model_name)
380414
message.set_result(
381415
MessageEventResult().message(
382416
f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
383417
),
384418
)
385419
except BaseException as e:
386-
err_msg = self._safe_err(e)
387420
message.set_result(
388-
MessageEventResult().message("跨提供商切换并设置模型失败: " + err_msg),
421+
MessageEventResult().message(
422+
self._format_err("跨提供商切换并设置模型失败: ", e)
423+
),
389424
)
390425

391426
async def model_ls(
@@ -406,10 +441,9 @@ async def model_ls(
406441
try:
407442
models = await self._get_provider_models(prov)
408443
except BaseException as e:
409-
err_msg = self._safe_err(e)
410444
message.set_result(
411445
MessageEventResult()
412-
.message("获取模型列表失败: " + err_msg)
446+
.message(self._format_err("获取模型列表失败: ", e))
413447
.use_t2i(False),
414448
)
415449
return
@@ -430,27 +464,28 @@ async def model_ls(
430464
try:
431465
models = await self._get_provider_models(prov)
432466
except BaseException as e:
433-
err_msg = self._safe_err(e)
434467
message.set_result(
435-
MessageEventResult().message("获取模型列表失败: " + err_msg),
468+
MessageEventResult().message(
469+
self._format_err("获取模型列表失败: ", e)
470+
),
436471
)
437472
return
438473
if idx_or_name > len(models) or idx_or_name < 1:
439474
message.set_result(MessageEventResult().message("模型序号错误。"))
440475
else:
441476
try:
442477
new_model = models[idx_or_name - 1]
443-
prov.set_model(new_model)
444-
self._invalidate_provider_models_cache(prov.meta().id)
478+
self._set_model_and_invalidate(prov, new_model)
445479
message.set_result(
446480
MessageEventResult().message(
447481
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
448482
),
449483
)
450484
except BaseException as e:
451-
err_msg = self._safe_err(e)
452485
message.set_result(
453-
MessageEventResult().message("切换模型未知错误: " + err_msg),
486+
MessageEventResult().message(
487+
self._format_err("切换模型未知错误: ", e)
488+
),
454489
)
455490
return
456491
else:
@@ -484,10 +519,11 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
484519
else:
485520
try:
486521
new_key = keys_data[index - 1]
487-
prov.set_key(new_key)
488-
self._invalidate_provider_models_cache(prov.meta().id)
522+
self._set_key_and_invalidate(prov, new_key)
489523
except BaseException as e:
490524
message.set_result(
491-
MessageEventResult().message(f"切换 Key 未知错误: {e!s}"),
525+
MessageEventResult().message(
526+
self._format_err("切换 Key 未知错误: ", e)
527+
),
492528
)
493529
message.set_result(MessageEventResult().message("切换 Key 成功。"))

0 commit comments

Comments
 (0)