22
33import asyncio
44import time
5- from collections .abc import Sequence
6- from typing import TYPE_CHECKING
5+ from collections .abc import Callable , Sequence
6+ from typing import TYPE_CHECKING , TypeVar
77
88from astrbot import logger
99from astrbot .api import star
2121MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds"
2222MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency"
2323
24+ T = TypeVar ("T" )
25+
2426
2527class ProviderCommands :
2628 def __init__ (self , context : star .Context ) -> None :
2729 self .context = context
28- self ._model_cache : dict [str , tuple [float , list [str ]]] = {}
30+ self ._model_cache : dict [
31+ tuple [str , str | None ], tuple [float , tuple [str , ...]]
32+ ] = {}
2933 self ._register_provider_change_hook ()
3034
3135 def _register_provider_change_hook (self ) -> None :
@@ -45,12 +49,21 @@ def _register_provider_change_hook(self) -> None:
4549 if callable (register_change_hook ):
4650 register_change_hook (self ._on_provider_manager_changed )
4751
48- def invalidate_provider_models_cache (self , provider_id : str | None = None ) -> None :
52+ def invalidate_provider_models_cache (
53+ self , provider_id : str | None = None , * , umo : str | None = None
54+ ) -> None :
4955 """Public hook for cache invalidation on external provider config changes."""
5056 if provider_id is None :
5157 self ._model_cache .clear ()
5258 return
53- self ._model_cache .pop (provider_id , None )
59+ if umo is not None :
60+ self ._model_cache .pop ((provider_id , umo ), None )
61+ return
62+ stale_keys = [
63+ cache_key for cache_key in self ._model_cache if cache_key [0 ] == provider_id
64+ ]
65+ for cache_key in stale_keys :
66+ self ._model_cache .pop (cache_key , None )
5467
5568 def _on_provider_manager_changed (
5669 self ,
@@ -59,60 +72,73 @@ def _on_provider_manager_changed(
5972 umo : str | None ,
6073 ) -> None :
6174 if provider_type == ProviderType .CHAT_COMPLETION :
62- self .invalidate_provider_models_cache (provider_id )
75+ self .invalidate_provider_models_cache (provider_id , umo = umo )
76+
77+ @staticmethod
78+ def _cache_key (provider_id : str , umo : str | None ) -> tuple [str , str | None ]:
79+ return provider_id , umo
6380
6481 def _get_cached_models (
65- self , provider_id : str , * , ttl_seconds : float
82+ self , provider_id : str , * , ttl_seconds : float , umo : str | None
6683 ) -> list [str ] | None :
6784 if ttl_seconds <= 0 :
6885 return None
69- entry = self ._model_cache .get (provider_id )
86+ entry = self ._model_cache .get (self . _cache_key ( provider_id , umo ) )
7087 if not entry :
7188 return None
7289 timestamp , models = entry
7390 if time .monotonic () - timestamp > ttl_seconds :
74- self ._model_cache .pop (provider_id , None )
91+ self ._model_cache .pop (self . _cache_key ( provider_id , umo ) , None )
7592 return None
7693 return list (models )
7794
78- def _set_cached_models (self , provider_id : str , models : list [str ]) -> None :
79- self ._model_cache [provider_id ] = (time .monotonic (), list (models ))
95+ def _set_cached_models (
96+ self , provider_id : str , models : list [str ], * , umo : str | None
97+ ) -> None :
98+ self ._model_cache [self ._cache_key (provider_id , umo )] = (
99+ time .monotonic (),
100+ tuple (models ),
101+ )
80102
81- def _get_ttl_setting (self , umo : str | None ) -> float :
103+ def _get_provider_setting (
104+ self ,
105+ umo : str | None ,
106+ key : str ,
107+ default : T ,
108+ cast : Callable [[object ], T ],
109+ ) -> T :
82110 if not umo :
83- return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
111+ return default
84112 try :
85113 cfg = self .context .get_config (umo ).get ("provider_settings" , {})
86- raw = cfg .get (MODEL_LIST_CACHE_TTL_KEY )
114+ raw = cfg .get (key )
87115 if raw is None :
88- return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
89- return float (raw )
116+ return default
117+ return cast (raw )
90118 except Exception as e :
91119 logger .debug (
92120 "读取 %s 失败,回退默认值 %r: %s" ,
93- MODEL_LIST_CACHE_TTL_KEY ,
94- MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT ,
121+ key ,
122+ default ,
95123 safe_error ("" , e ),
96124 )
97- return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
125+ return default
126+
127+ def _get_ttl_setting (self , umo : str | None ) -> float :
128+ return self ._get_provider_setting (
129+ umo ,
130+ MODEL_LIST_CACHE_TTL_KEY ,
131+ MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT ,
132+ float ,
133+ )
98134
99135 def _get_lookup_concurrency (self , umo : str | None ) -> int :
100- if not umo :
101- return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
102- try :
103- cfg = self .context .get_config (umo ).get ("provider_settings" , {})
104- raw = cfg .get (MODEL_LOOKUP_MAX_CONCURRENCY_KEY )
105- if raw is None :
106- return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
107- return int (raw )
108- except Exception as e :
109- logger .debug (
110- "读取 %s 失败,回退默认值 %r: %s" ,
111- MODEL_LOOKUP_MAX_CONCURRENCY_KEY ,
112- MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT ,
113- safe_error ("" , e ),
114- )
115- return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
136+ return self ._get_provider_setting (
137+ umo ,
138+ MODEL_LOOKUP_MAX_CONCURRENCY_KEY ,
139+ MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT ,
140+ int ,
141+ )
116142
117143 def _resolve_model_name (
118144 self ,
@@ -135,18 +161,20 @@ def _resolve_model_name(
135161
136162 # provider-qualified suffix match:
137163 # e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`.
138- def _match_qualified_suffix (req : str , cand : str ) -> bool :
139- return cand .endswith (f"/{ req } " ) or cand .endswith (f":{ req } " )
140-
141164 for candidate in models :
142- if _match_qualified_suffix (requested_norm , candidate .casefold ()):
165+ cand_norm = candidate .casefold ()
166+ if cand_norm .endswith (f"/{ requested_norm } " ) or cand_norm .endswith (
167+ f":{ requested_norm } "
168+ ):
143169 return candidate
144170
145171 return None
146172
147- def _apply_model (self , prov : Provider , model_name : str ) -> str :
173+ def _apply_model (
174+ self , prov : Provider , model_name : str , * , umo : str | None = None
175+ ) -> str :
148176 prov .set_model (model_name )
149- self .invalidate_provider_models_cache (prov .meta ().id )
177+ self .invalidate_provider_models_cache (prov .meta ().id , umo = umo )
150178 return f"切换模型成功。当前提供商: [{ prov .meta ().id } ] 当前模型: [{ prov .get_model ()} ]"
151179
152180 async def _get_provider_models (
@@ -159,13 +187,17 @@ async def _get_provider_models(
159187 provider_id = provider .meta ().id
160188 ttl_seconds = max (float (self ._get_ttl_setting (umo )), 0.0 )
161189 if use_cache :
162- cached = self ._get_cached_models (provider_id , ttl_seconds = ttl_seconds )
190+ cached = self ._get_cached_models (
191+ provider_id ,
192+ ttl_seconds = ttl_seconds ,
193+ umo = umo ,
194+ )
163195 if cached is not None :
164196 return cached
165197
166198 models = list (await provider .get_models ())
167199 if use_cache and ttl_seconds > 0 :
168- self ._set_cached_models (provider_id , models )
200+ self ._set_cached_models (provider_id , models , umo = umo )
169201 return models
170202
171203 def _log_reachability_failure (
@@ -484,7 +516,7 @@ async def _switch_model_by_name(
484516 if matched_model_name is not None :
485517 message .set_result (
486518 MessageEventResult ().message (
487- self ._apply_model (prov , matched_model_name )
519+ self ._apply_model (prov , matched_model_name , umo = umo )
488520 ),
489521 )
490522 return
@@ -508,8 +540,7 @@ async def _switch_model_by_name(
508540 provider_type = ProviderType .CHAT_COMPLETION ,
509541 umo = umo ,
510542 )
511- target_prov .set_model (matched_target_model_name )
512- self .invalidate_provider_models_cache (target_prov .meta ().id )
543+ self ._apply_model (target_prov , matched_target_model_name , umo = umo )
513544 message .set_result (
514545 MessageEventResult ().message (
515546 f"检测到模型 [{ matched_target_model_name } ] 属于提供商 [{ target_id } ],已自动切换提供商并设置模型。" ,
@@ -584,7 +615,11 @@ async def model_ls(
584615 new_model = models [idx_or_name - 1 ]
585616 message .set_result (
586617 MessageEventResult ().message (
587- self ._apply_model (prov , new_model )
618+ self ._apply_model (
619+ prov ,
620+ new_model ,
621+ umo = message .unified_msg_origin ,
622+ )
588623 ),
589624 )
590625 except Exception as e :
@@ -626,7 +661,10 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
626661 try :
627662 new_key = keys_data [index - 1 ]
628663 prov .set_key (new_key )
629- self .invalidate_provider_models_cache (prov .meta ().id )
664+ self .invalidate_provider_models_cache (
665+ prov .meta ().id ,
666+ umo = message .unified_msg_origin ,
667+ )
630668 message .set_result (MessageEventResult ().message ("切换 Key 成功。" ))
631669 except Exception as e :
632670 message .set_result (
0 commit comments