@@ -31,13 +31,27 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N
3131 return
3232 self ._provider_models_cache .pop (provider_id , None )
3333
34+ def _invalidate_cache_for (self , provider : "Provider" ) -> None :
35+ self ._invalidate_provider_models_cache (provider .meta ().id )
36+
37+ def _set_model_and_invalidate (self , provider : "Provider" , model_name : str ) -> None :
38+ provider .set_model (model_name )
39+ self ._invalidate_cache_for (provider )
40+
41+ def _set_key_and_invalidate (self , provider : "Provider" , key : str ) -> None :
42+ provider .set_key (key )
43+ self ._invalidate_cache_for (provider )
44+
3445 @staticmethod
3546 def _mask_sensitive_text (value : str ) -> str :
3647 return _API_KEY_PATTERN .sub ("key=***" , value )
3748
3849 def _safe_err (self , e : BaseException ) -> str :
3950 return self ._mask_sensitive_text (str (e ))
4051
52+ def _format_err (self , prefix : str , e : BaseException ) -> str :
53+ return f"{ prefix } { self ._safe_err (e )} "
54+
4155 async def _get_provider_models (
4256 self , provider : "Provider" , * , use_cache : bool = True
4357 ) -> list [str ]:
@@ -96,18 +110,38 @@ async def _find_provider_for_model(
96110 ]
97111 if not all_providers :
98112 return None
99- results = await asyncio .gather (
100- * [self ._get_provider_models (p ) for p in all_providers ],
101- return_exceptions = True ,
102- )
113+
114+ async def _fetch_models (
115+ provider : "Provider" ,
116+ ) -> tuple ["Provider" , list [str ] | None , BaseException | None ]:
117+ try :
118+ return provider , await self ._get_provider_models (provider ), None
119+ except BaseException as e :
120+ return provider , None , e
121+
122+ tasks = [
123+ asyncio .create_task (_fetch_models (provider )) for provider in all_providers
124+ ]
103125 failed_provider_errors : list [tuple [str , str ]] = []
104- for provider , result in zip (all_providers , results ):
105- if isinstance (result , BaseException ):
106- masked_error = self ._safe_err (result )
126+ matched_provider : Provider | None = None
127+ for task in asyncio .as_completed (tasks ):
128+ provider , models , error = await task
129+ if error is not None :
130+ masked_error = self ._safe_err (error )
107131 failed_provider_errors .append ((provider .meta ().id , masked_error ))
108132 continue
109- if model_name in result :
110- return provider
133+
134+ if models is not None and model_name in models :
135+ matched_provider = provider
136+ break
137+
138+ if matched_provider is not None :
139+ for task in tasks :
140+ if not task .done ():
141+ task .cancel ()
142+ await asyncio .gather (* tasks , return_exceptions = True )
143+ return matched_provider
144+
111145 if failed_provider_errors and len (failed_provider_errors ) == len (all_providers ):
112146 failed_ids = "," .join (
113147 provider_id for provider_id , _ in failed_provider_errors
@@ -334,13 +368,14 @@ async def _switch_model_by_name(
334368 err_msg ,
335369 )
336370 message .set_result (
337- MessageEventResult ().message ("获取当前提供商模型列表失败: " + err_msg ),
371+ MessageEventResult ().message (
372+ self ._format_err ("获取当前提供商模型列表失败: " , e )
373+ )
338374 )
339375 return
340376
341377 if model_name in models :
342- prov .set_model (model_name )
343- self ._invalidate_provider_models_cache (curr_provider_id )
378+ self ._set_model_and_invalidate (prov , model_name )
344379 message .set_result (
345380 MessageEventResult ().message (
346381 f"切换模型成功。当前提供商: [{ curr_provider_id } ] 当前模型: [{ model_name } ]" ,
@@ -375,17 +410,17 @@ async def _switch_model_by_name(
375410 provider_type = ProviderType .CHAT_COMPLETION ,
376411 umo = umo ,
377412 )
378- target_prov .set_model (model_name )
379- self ._invalidate_provider_models_cache (target_id )
413+ self ._set_model_and_invalidate (target_prov , model_name )
380414 message .set_result (
381415 MessageEventResult ().message (
382416 f"检测到模型 [{ model_name } ] 属于提供商 [{ target_id } ],已自动切换提供商并设置模型。" ,
383417 ),
384418 )
385419 except BaseException as e :
386- err_msg = self ._safe_err (e )
387420 message .set_result (
388- MessageEventResult ().message ("跨提供商切换并设置模型失败: " + err_msg ),
421+ MessageEventResult ().message (
422+ self ._format_err ("跨提供商切换并设置模型失败: " , e )
423+ ),
389424 )
390425
391426 async def model_ls (
@@ -406,10 +441,9 @@ async def model_ls(
406441 try :
407442 models = await self ._get_provider_models (prov )
408443 except BaseException as e :
409- err_msg = self ._safe_err (e )
410444 message .set_result (
411445 MessageEventResult ()
412- .message ("获取模型列表失败: " + err_msg )
446+ .message (self . _format_err ( "获取模型列表失败: " , e ) )
413447 .use_t2i (False ),
414448 )
415449 return
@@ -430,27 +464,28 @@ async def model_ls(
430464 try :
431465 models = await self ._get_provider_models (prov )
432466 except BaseException as e :
433- err_msg = self ._safe_err (e )
434467 message .set_result (
435- MessageEventResult ().message ("获取模型列表失败: " + err_msg ),
468+ MessageEventResult ().message (
469+ self ._format_err ("获取模型列表失败: " , e )
470+ ),
436471 )
437472 return
438473 if idx_or_name > len (models ) or idx_or_name < 1 :
439474 message .set_result (MessageEventResult ().message ("模型序号错误。" ))
440475 else :
441476 try :
442477 new_model = models [idx_or_name - 1 ]
443- prov .set_model (new_model )
444- self ._invalidate_provider_models_cache (prov .meta ().id )
478+ self ._set_model_and_invalidate (prov , new_model )
445479 message .set_result (
446480 MessageEventResult ().message (
447481 f"切换模型成功。当前提供商: [{ prov .meta ().id } ] 当前模型: [{ prov .get_model ()} ]" ,
448482 ),
449483 )
450484 except BaseException as e :
451- err_msg = self ._safe_err (e )
452485 message .set_result (
453- MessageEventResult ().message ("切换模型未知错误: " + err_msg ),
486+ MessageEventResult ().message (
487+ self ._format_err ("切换模型未知错误: " , e )
488+ ),
454489 )
455490 return
456491 else :
@@ -484,10 +519,11 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
484519 else :
485520 try :
486521 new_key = keys_data [index - 1 ]
487- prov .set_key (new_key )
488- self ._invalidate_provider_models_cache (prov .meta ().id )
522+ self ._set_key_and_invalidate (prov , new_key )
489523 except BaseException as e :
490524 message .set_result (
491- MessageEventResult ().message (f"切换 Key 未知错误: { e !s} " ),
525+ MessageEventResult ().message (
526+ self ._format_err ("切换 Key 未知错误: " , e )
527+ ),
492528 )
493529 message .set_result (MessageEventResult ().message ("切换 Key 成功。" ))
0 commit comments