1616
1717_SECRET_PATTERNS = [
1818 re .compile (
19- r"(?i)\b(api_?key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+"
19+ r"(?i)\b(api_?key|key| access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+"
2020 ),
2121 re .compile (r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+" ),
2222 re .compile (r"(?i)\bbearer\s+[A-Za-z0-9._\-]+" ),
@@ -39,6 +39,7 @@ class _ModelCacheEntry:
3939
4040class ProviderCommands :
4141 _MODEL_LIST_CACHE_TTL_SECONDS = 30.0
42+ _MODEL_LOOKUP_MAX_CONCURRENCY = 4
4243
4344 def __init__ (self , context : star .Context ) -> None :
4445 self .context = context
@@ -50,23 +51,47 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N
5051 return
5152 self ._provider_models_cache .pop (provider_id , None )
5253
53- def _invalidate_cache_for (self , provider : Provider ) -> None :
54+ def _update_provider_and_invalidate (
55+ self ,
56+ provider : Provider ,
57+ * ,
58+ model_name : str | None = None ,
59+ key : str | None = None ,
60+ ) -> None :
61+ if model_name is not None :
62+ provider .set_model (model_name )
63+ if key is not None :
64+ provider .set_key (key )
5465 self ._invalidate_provider_models_cache (provider .meta ().id )
5566
56- def _set_model_and_invalidate (self , provider : Provider , model_name : str ) -> None :
57- provider .set_model (model_name )
58- self ._invalidate_cache_for (provider )
59-
60- def _set_key_and_invalidate (self , provider : Provider , key : str ) -> None :
61- provider .set_key (key )
62- self ._invalidate_cache_for (provider )
67+ @staticmethod
68+ def _safe_err (prefix : str , e : Exception ) -> str :
69+ return prefix + redact_secrets (str (e ))
6370
6471 @staticmethod
65- def _mask_sensitive_text ( value : str ) -> str :
66- return redact_secrets ( value )
72+ def _normalize_model_name ( model_name : str ) -> str :
73+ return model_name . strip (). casefold ( )
6774
68- def _format_safe_err (self , prefix : str , e : Exception ) -> str :
69- return f"{ prefix } { self ._mask_sensitive_text (str (e ))} "
75+ def _resolve_model_name (self , model_name : str , models : list [str ]) -> str | None :
76+ normalized_model_name = self ._normalize_model_name (model_name )
77+ if not normalized_model_name :
78+ return None
79+ if model_name in models :
80+ return model_name
81+
82+ for candidate in models :
83+ normalized_candidate = self ._normalize_model_name (candidate )
84+ if normalized_candidate == normalized_model_name :
85+ return candidate
86+ if normalized_candidate .endswith (
87+ f"/{ normalized_model_name } "
88+ ) or normalized_candidate .endswith (f":{ normalized_model_name } " ):
89+ return candidate
90+ if normalized_model_name .endswith (
91+ f"/{ normalized_candidate } "
92+ ) or normalized_model_name .endswith (f":{ normalized_candidate } " ):
93+ return candidate
94+ return None
7095
7196 def _get_model_cache_ttl_seconds (self , umo : str | None = None ) -> float :
7297 ttl = self ._MODEL_LIST_CACHE_TTL_SECONDS
@@ -81,7 +106,7 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float:
81106 logger .debug (
82107 "读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s" ,
83108 self ._MODEL_LIST_CACHE_TTL_SECONDS ,
84- self . _mask_sensitive_text (str (e )),
109+ redact_secrets (str (e )),
85110 )
86111 ttl = self ._MODEL_LIST_CACHE_TTL_SECONDS
87112 return max (ttl , 0.0 )
@@ -136,7 +161,7 @@ async def _test_provider_capability(self, provider):
136161 return True , None , None
137162 except Exception as e :
138163 err_code = "TEST_FAILED"
139- err_reason = self . _mask_sensitive_text (str (e ))
164+ err_reason = redact_secrets (str (e ))
140165 self ._log_reachability_failure (
141166 provider , provider_capability_type , err_code , err_reason
142167 )
@@ -147,28 +172,45 @@ async def _find_provider_for_model(
147172 model_name : str ,
148173 exclude_provider_id : str | None = None ,
149174 umo : str | None = None ,
150- ) -> Provider | None :
175+ ) -> tuple [ Provider | None , str | None ] :
151176 """在所有 LLM 提供商中查找包含指定模型的提供商。"""
152177 all_providers = [
153178 p
154179 for p in self .context .get_all_providers ()
155180 if not exclude_provider_id or p .meta ().id != exclude_provider_id
156181 ]
157182 if not all_providers :
158- return None
183+ return None , None
184+
185+ semaphore = asyncio .Semaphore (self ._MODEL_LOOKUP_MAX_CONCURRENCY )
186+
187+ async def _fetch_models (
188+ provider : Provider ,
189+ ) -> tuple [Provider , list [str ] | None , Exception | None ]:
190+ async with semaphore :
191+ try :
192+ return (
193+ provider ,
194+ await self ._get_provider_models (provider , umo = umo ),
195+ None ,
196+ )
197+ except Exception as e :
198+ return provider , None , e
199+
200+ results = await asyncio .gather (
201+ * [_fetch_models (provider ) for provider in all_providers ]
202+ )
159203 failed_provider_errors : list [tuple [str , str ]] = []
160- for provider in all_providers :
204+ for provider , models , error in results :
161205 provider_id = provider .meta ().id
162- try :
163- models = await self ._get_provider_models (provider , umo = umo )
164- except Exception as e :
165- failed_provider_errors .append (
166- (provider_id , self ._format_safe_err ("" , e ))
167- )
206+ if error is not None :
207+ failed_provider_errors .append ((provider_id , self ._safe_err ("" , error )))
168208 continue
169-
170- if model_name in models :
171- return provider
209+ if models is None :
210+ continue
211+ matched_model_name = self ._resolve_model_name (model_name , models )
212+ if matched_model_name is not None :
213+ return provider , matched_model_name
172214
173215 if failed_provider_errors and len (failed_provider_errors ) == len (all_providers ):
174216 failed_ids = "," .join (
@@ -190,7 +232,7 @@ async def _find_provider_for_model(
190232 for provider_id , error in failed_provider_errors
191233 ),
192234 )
193- return None
235+ return None , None
194236
195237 async def provider (
196238 self ,
@@ -246,7 +288,7 @@ async def provider(
246288 p ,
247289 None ,
248290 reachable .__class__ .__name__ ,
249- self . _mask_sensitive_text (str (reachable )),
291+ redact_secrets (str (reachable )),
250292 )
251293 reachable_flag = False
252294 error_code = reachable .__class__ .__name__
@@ -386,33 +428,34 @@ async def _switch_model_by_name(
386428 try :
387429 models = await self ._get_provider_models (prov , umo = umo )
388430 except Exception as e :
389- err_msg = self ._format_safe_err ("" , e )
431+ err_msg = self ._safe_err ("" , e )
390432 logger .warning (
391433 "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s" ,
392434 curr_provider_id ,
393435 err_msg ,
394436 )
395437 message .set_result (
396438 MessageEventResult ().message (
397- self ._format_safe_err ("获取当前提供商模型列表失败: " , e )
439+ self ._safe_err ("获取当前提供商模型列表失败: " , e )
398440 )
399441 )
400442 return
401443
402- if model_name in models :
403- self ._set_model_and_invalidate (prov , model_name )
444+ matched_model_name = self ._resolve_model_name (model_name , models )
445+ if matched_model_name is not None :
446+ self ._update_provider_and_invalidate (prov , model_name = matched_model_name )
404447 message .set_result (
405448 MessageEventResult ().message (
406- f"切换模型成功。当前提供商: [{ curr_provider_id } ] 当前模型: [{ model_name } ]" ,
449+ f"切换模型成功。当前提供商: [{ curr_provider_id } ] 当前模型: [{ matched_model_name } ]" ,
407450 ),
408451 )
409452 return
410453
411- target_prov = await self ._find_provider_for_model (
454+ target_prov , matched_target_model_name = await self ._find_provider_for_model (
412455 model_name , exclude_provider_id = curr_provider_id , umo = umo
413456 )
414457
415- if not target_prov :
458+ if target_prov is None or matched_target_model_name is None :
416459 message .set_result (
417460 MessageEventResult ().message (
418461 f"模型 [{ model_name } ] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。" ,
@@ -427,16 +470,18 @@ async def _switch_model_by_name(
427470 provider_type = ProviderType .CHAT_COMPLETION ,
428471 umo = umo ,
429472 )
430- self ._set_model_and_invalidate (target_prov , model_name )
473+ self ._update_provider_and_invalidate (
474+ target_prov , model_name = matched_target_model_name
475+ )
431476 message .set_result (
432477 MessageEventResult ().message (
433- f"检测到模型 [{ model_name } ] 属于提供商 [{ target_id } ],已自动切换提供商并设置模型。" ,
478+ f"检测到模型 [{ matched_target_model_name } ] 属于提供商 [{ target_id } ],已自动切换提供商并设置模型。" ,
434479 ),
435480 )
436481 except Exception as e :
437482 message .set_result (
438483 MessageEventResult ().message (
439- self ._format_safe_err ("跨提供商切换并设置模型失败: " , e )
484+ self ._safe_err ("跨提供商切换并设置模型失败: " , e )
440485 ),
441486 )
442487
@@ -462,7 +507,7 @@ async def model_ls(
462507 except Exception as e :
463508 message .set_result (
464509 MessageEventResult ()
465- .message (self ._format_safe_err ("获取模型列表失败: " , e ))
510+ .message (self ._safe_err ("获取模型列表失败: " , e ))
466511 .use_t2i (False ),
467512 )
468513 return
@@ -487,7 +532,7 @@ async def model_ls(
487532 except Exception as e :
488533 message .set_result (
489534 MessageEventResult ().message (
490- self ._format_safe_err ("获取模型列表失败: " , e )
535+ self ._safe_err ("获取模型列表失败: " , e )
491536 ),
492537 )
493538 return
@@ -496,7 +541,7 @@ async def model_ls(
496541 else :
497542 try :
498543 new_model = models [idx_or_name - 1 ]
499- self ._set_model_and_invalidate (prov , new_model )
544+ self ._update_provider_and_invalidate (prov , model_name = new_model )
500545 message .set_result (
501546 MessageEventResult ().message (
502547 f"切换模型成功。当前提供商: [{ prov .meta ().id } ] 当前模型: [{ prov .get_model ()} ]" ,
@@ -505,7 +550,7 @@ async def model_ls(
505550 except Exception as e :
506551 message .set_result (
507552 MessageEventResult ().message (
508- self ._format_safe_err ("切换模型未知错误: " , e )
553+ self ._safe_err ("切换模型未知错误: " , e )
509554 ),
510555 )
511556 return
@@ -540,12 +585,12 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
540585 else :
541586 try :
542587 new_key = keys_data [index - 1 ]
543- self ._set_key_and_invalidate (prov , new_key )
588+ self ._update_provider_and_invalidate (prov , key = new_key )
544589 message .set_result (MessageEventResult ().message ("切换 Key 成功。" ))
545590 except Exception as e :
546591 message .set_result (
547592 MessageEventResult ().message (
548- self ._format_safe_err ("切换 Key 未知错误: " , e )
593+ self ._safe_err ("切换 Key 未知错误: " , e )
549594 ),
550595 )
551596 return
0 commit comments