Skip to content

Commit 5293aef

Browse files
committed
fix: improve provider model lookup and secret redaction
1 parent 1d611d4 commit 5293aef

1 file changed

Lines changed: 90 additions & 45 deletions

File tree

  • astrbot/builtin_stars/builtin_commands/commands

astrbot/builtin_stars/builtin_commands/commands/provider.py

Lines changed: 90 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
_SECRET_PATTERNS = [
1818
re.compile(
19-
r"(?i)\b(api_?key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+"
19+
r"(?i)\b(api_?key|key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+"
2020
),
2121
re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"),
2222
re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"),
@@ -39,6 +39,7 @@ class _ModelCacheEntry:
3939

4040
class ProviderCommands:
4141
_MODEL_LIST_CACHE_TTL_SECONDS = 30.0
42+
_MODEL_LOOKUP_MAX_CONCURRENCY = 4
4243

4344
def __init__(self, context: star.Context) -> None:
4445
self.context = context
@@ -50,23 +51,47 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N
5051
return
5152
self._provider_models_cache.pop(provider_id, None)
5253

53-
def _invalidate_cache_for(self, provider: Provider) -> None:
54+
def _update_provider_and_invalidate(
55+
self,
56+
provider: Provider,
57+
*,
58+
model_name: str | None = None,
59+
key: str | None = None,
60+
) -> None:
61+
if model_name is not None:
62+
provider.set_model(model_name)
63+
if key is not None:
64+
provider.set_key(key)
5465
self._invalidate_provider_models_cache(provider.meta().id)
5566

56-
def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None:
57-
provider.set_model(model_name)
58-
self._invalidate_cache_for(provider)
59-
60-
def _set_key_and_invalidate(self, provider: Provider, key: str) -> None:
61-
provider.set_key(key)
62-
self._invalidate_cache_for(provider)
67+
@staticmethod
68+
def _safe_err(prefix: str, e: Exception) -> str:
69+
return prefix + redact_secrets(str(e))
6370

6471
@staticmethod
65-
def _mask_sensitive_text(value: str) -> str:
66-
return redact_secrets(value)
72+
def _normalize_model_name(model_name: str) -> str:
73+
return model_name.strip().casefold()
6774

68-
def _format_safe_err(self, prefix: str, e: Exception) -> str:
69-
return f"{prefix}{self._mask_sensitive_text(str(e))}"
75+
def _resolve_model_name(self, model_name: str, models: list[str]) -> str | None:
76+
normalized_model_name = self._normalize_model_name(model_name)
77+
if not normalized_model_name:
78+
return None
79+
if model_name in models:
80+
return model_name
81+
82+
for candidate in models:
83+
normalized_candidate = self._normalize_model_name(candidate)
84+
if normalized_candidate == normalized_model_name:
85+
return candidate
86+
if normalized_candidate.endswith(
87+
f"/{normalized_model_name}"
88+
) or normalized_candidate.endswith(f":{normalized_model_name}"):
89+
return candidate
90+
if normalized_model_name.endswith(
91+
f"/{normalized_candidate}"
92+
) or normalized_model_name.endswith(f":{normalized_candidate}"):
93+
return candidate
94+
return None
7095

7196
def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float:
7297
ttl = self._MODEL_LIST_CACHE_TTL_SECONDS
@@ -81,7 +106,7 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float:
81106
logger.debug(
82107
"读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s",
83108
self._MODEL_LIST_CACHE_TTL_SECONDS,
84-
self._mask_sensitive_text(str(e)),
109+
redact_secrets(str(e)),
85110
)
86111
ttl = self._MODEL_LIST_CACHE_TTL_SECONDS
87112
return max(ttl, 0.0)
@@ -136,7 +161,7 @@ async def _test_provider_capability(self, provider):
136161
return True, None, None
137162
except Exception as e:
138163
err_code = "TEST_FAILED"
139-
err_reason = self._mask_sensitive_text(str(e))
164+
err_reason = redact_secrets(str(e))
140165
self._log_reachability_failure(
141166
provider, provider_capability_type, err_code, err_reason
142167
)
@@ -147,28 +172,45 @@ async def _find_provider_for_model(
147172
model_name: str,
148173
exclude_provider_id: str | None = None,
149174
umo: str | None = None,
150-
) -> Provider | None:
175+
) -> tuple[Provider | None, str | None]:
151176
"""在所有 LLM 提供商中查找包含指定模型的提供商。"""
152177
all_providers = [
153178
p
154179
for p in self.context.get_all_providers()
155180
if not exclude_provider_id or p.meta().id != exclude_provider_id
156181
]
157182
if not all_providers:
158-
return None
183+
return None, None
184+
185+
semaphore = asyncio.Semaphore(self._MODEL_LOOKUP_MAX_CONCURRENCY)
186+
187+
async def _fetch_models(
188+
provider: Provider,
189+
) -> tuple[Provider, list[str] | None, Exception | None]:
190+
async with semaphore:
191+
try:
192+
return (
193+
provider,
194+
await self._get_provider_models(provider, umo=umo),
195+
None,
196+
)
197+
except Exception as e:
198+
return provider, None, e
199+
200+
results = await asyncio.gather(
201+
*[_fetch_models(provider) for provider in all_providers]
202+
)
159203
failed_provider_errors: list[tuple[str, str]] = []
160-
for provider in all_providers:
204+
for provider, models, error in results:
161205
provider_id = provider.meta().id
162-
try:
163-
models = await self._get_provider_models(provider, umo=umo)
164-
except Exception as e:
165-
failed_provider_errors.append(
166-
(provider_id, self._format_safe_err("", e))
167-
)
206+
if error is not None:
207+
failed_provider_errors.append((provider_id, self._safe_err("", error)))
168208
continue
169-
170-
if model_name in models:
171-
return provider
209+
if models is None:
210+
continue
211+
matched_model_name = self._resolve_model_name(model_name, models)
212+
if matched_model_name is not None:
213+
return provider, matched_model_name
172214

173215
if failed_provider_errors and len(failed_provider_errors) == len(all_providers):
174216
failed_ids = ",".join(
@@ -190,7 +232,7 @@ async def _find_provider_for_model(
190232
for provider_id, error in failed_provider_errors
191233
),
192234
)
193-
return None
235+
return None, None
194236

195237
async def provider(
196238
self,
@@ -246,7 +288,7 @@ async def provider(
246288
p,
247289
None,
248290
reachable.__class__.__name__,
249-
self._mask_sensitive_text(str(reachable)),
291+
redact_secrets(str(reachable)),
250292
)
251293
reachable_flag = False
252294
error_code = reachable.__class__.__name__
@@ -386,33 +428,34 @@ async def _switch_model_by_name(
386428
try:
387429
models = await self._get_provider_models(prov, umo=umo)
388430
except Exception as e:
389-
err_msg = self._format_safe_err("", e)
431+
err_msg = self._safe_err("", e)
390432
logger.warning(
391433
"获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
392434
curr_provider_id,
393435
err_msg,
394436
)
395437
message.set_result(
396438
MessageEventResult().message(
397-
self._format_safe_err("获取当前提供商模型列表失败: ", e)
439+
self._safe_err("获取当前提供商模型列表失败: ", e)
398440
)
399441
)
400442
return
401443

402-
if model_name in models:
403-
self._set_model_and_invalidate(prov, model_name)
444+
matched_model_name = self._resolve_model_name(model_name, models)
445+
if matched_model_name is not None:
446+
self._update_provider_and_invalidate(prov, model_name=matched_model_name)
404447
message.set_result(
405448
MessageEventResult().message(
406-
f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]",
449+
f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]",
407450
),
408451
)
409452
return
410453

411-
target_prov = await self._find_provider_for_model(
454+
target_prov, matched_target_model_name = await self._find_provider_for_model(
412455
model_name, exclude_provider_id=curr_provider_id, umo=umo
413456
)
414457

415-
if not target_prov:
458+
if target_prov is None or matched_target_model_name is None:
416459
message.set_result(
417460
MessageEventResult().message(
418461
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
@@ -427,16 +470,18 @@ async def _switch_model_by_name(
427470
provider_type=ProviderType.CHAT_COMPLETION,
428471
umo=umo,
429472
)
430-
self._set_model_and_invalidate(target_prov, model_name)
473+
self._update_provider_and_invalidate(
474+
target_prov, model_name=matched_target_model_name
475+
)
431476
message.set_result(
432477
MessageEventResult().message(
433-
f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
478+
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
434479
),
435480
)
436481
except Exception as e:
437482
message.set_result(
438483
MessageEventResult().message(
439-
self._format_safe_err("跨提供商切换并设置模型失败: ", e)
484+
self._safe_err("跨提供商切换并设置模型失败: ", e)
440485
),
441486
)
442487

@@ -462,7 +507,7 @@ async def model_ls(
462507
except Exception as e:
463508
message.set_result(
464509
MessageEventResult()
465-
.message(self._format_safe_err("获取模型列表失败: ", e))
510+
.message(self._safe_err("获取模型列表失败: ", e))
466511
.use_t2i(False),
467512
)
468513
return
@@ -487,7 +532,7 @@ async def model_ls(
487532
except Exception as e:
488533
message.set_result(
489534
MessageEventResult().message(
490-
self._format_safe_err("获取模型列表失败: ", e)
535+
self._safe_err("获取模型列表失败: ", e)
491536
),
492537
)
493538
return
@@ -496,7 +541,7 @@ async def model_ls(
496541
else:
497542
try:
498543
new_model = models[idx_or_name - 1]
499-
self._set_model_and_invalidate(prov, new_model)
544+
self._update_provider_and_invalidate(prov, model_name=new_model)
500545
message.set_result(
501546
MessageEventResult().message(
502547
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
@@ -505,7 +550,7 @@ async def model_ls(
505550
except Exception as e:
506551
message.set_result(
507552
MessageEventResult().message(
508-
self._format_safe_err("切换模型未知错误: ", e)
553+
self._safe_err("切换模型未知错误: ", e)
509554
),
510555
)
511556
return
@@ -540,12 +585,12 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
540585
else:
541586
try:
542587
new_key = keys_data[index - 1]
543-
self._set_key_and_invalidate(prov, new_key)
588+
self._update_provider_and_invalidate(prov, key=new_key)
544589
message.set_result(MessageEventResult().message("切换 Key 成功。"))
545590
except Exception as e:
546591
message.set_result(
547592
MessageEventResult().message(
548-
self._format_safe_err("切换 Key 未知错误: ", e)
593+
self._safe_err("切换 Key 未知错误: ", e)
549594
),
550595
)
551596
return

0 commit comments

Comments
 (0)