Skip to content

Commit 086a11e

Browse files
committed
refactor: scope provider model cache by session
1 parent f139be7 commit 086a11e

1 file changed

Lines changed: 86 additions & 48 deletions

File tree

  • astrbot/builtin_stars/builtin_commands/commands

astrbot/builtin_stars/builtin_commands/commands/provider.py

Lines changed: 86 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22

33
import asyncio
44
import time
5-
from collections.abc import Sequence
6-
from typing import TYPE_CHECKING
5+
from collections.abc import Callable, Sequence
6+
from typing import TYPE_CHECKING, TypeVar
77

88
from astrbot import logger
99
from astrbot.api import star
@@ -21,11 +21,15 @@
2121
MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds"
2222
MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency"
2323

24+
T = TypeVar("T")
25+
2426

2527
class ProviderCommands:
2628
def __init__(self, context: star.Context) -> None:
2729
self.context = context
28-
self._model_cache: dict[str, tuple[float, list[str]]] = {}
30+
self._model_cache: dict[
31+
tuple[str, str | None], tuple[float, tuple[str, ...]]
32+
] = {}
2933
self._register_provider_change_hook()
3034

3135
def _register_provider_change_hook(self) -> None:
@@ -45,12 +49,21 @@ def _register_provider_change_hook(self) -> None:
4549
if callable(register_change_hook):
4650
register_change_hook(self._on_provider_manager_changed)
4751

48-
def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None:
52+
def invalidate_provider_models_cache(
53+
self, provider_id: str | None = None, *, umo: str | None = None
54+
) -> None:
4955
"""Public hook for cache invalidation on external provider config changes."""
5056
if provider_id is None:
5157
self._model_cache.clear()
5258
return
53-
self._model_cache.pop(provider_id, None)
59+
if umo is not None:
60+
self._model_cache.pop((provider_id, umo), None)
61+
return
62+
stale_keys = [
63+
cache_key for cache_key in self._model_cache if cache_key[0] == provider_id
64+
]
65+
for cache_key in stale_keys:
66+
self._model_cache.pop(cache_key, None)
5467

5568
def _on_provider_manager_changed(
5669
self,
@@ -59,60 +72,73 @@ def _on_provider_manager_changed(
5972
umo: str | None,
6073
) -> None:
6174
if provider_type == ProviderType.CHAT_COMPLETION:
62-
self.invalidate_provider_models_cache(provider_id)
75+
self.invalidate_provider_models_cache(provider_id, umo=umo)
76+
77+
@staticmethod
78+
def _cache_key(provider_id: str, umo: str | None) -> tuple[str, str | None]:
79+
return provider_id, umo
6380

6481
def _get_cached_models(
65-
self, provider_id: str, *, ttl_seconds: float
82+
self, provider_id: str, *, ttl_seconds: float, umo: str | None
6683
) -> list[str] | None:
6784
if ttl_seconds <= 0:
6885
return None
69-
entry = self._model_cache.get(provider_id)
86+
entry = self._model_cache.get(self._cache_key(provider_id, umo))
7087
if not entry:
7188
return None
7289
timestamp, models = entry
7390
if time.monotonic() - timestamp > ttl_seconds:
74-
self._model_cache.pop(provider_id, None)
91+
self._model_cache.pop(self._cache_key(provider_id, umo), None)
7592
return None
7693
return list(models)
7794

78-
def _set_cached_models(self, provider_id: str, models: list[str]) -> None:
79-
self._model_cache[provider_id] = (time.monotonic(), list(models))
95+
def _set_cached_models(
96+
self, provider_id: str, models: list[str], *, umo: str | None
97+
) -> None:
98+
self._model_cache[self._cache_key(provider_id, umo)] = (
99+
time.monotonic(),
100+
tuple(models),
101+
)
80102

81-
def _get_ttl_setting(self, umo: str | None) -> float:
103+
def _get_provider_setting(
104+
self,
105+
umo: str | None,
106+
key: str,
107+
default: T,
108+
cast: Callable[[object], T],
109+
) -> T:
82110
if not umo:
83-
return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
111+
return default
84112
try:
85113
cfg = self.context.get_config(umo).get("provider_settings", {})
86-
raw = cfg.get(MODEL_LIST_CACHE_TTL_KEY)
114+
raw = cfg.get(key)
87115
if raw is None:
88-
return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
89-
return float(raw)
116+
return default
117+
return cast(raw)
90118
except Exception as e:
91119
logger.debug(
92120
"读取 %s 失败,回退默认值 %r: %s",
93-
MODEL_LIST_CACHE_TTL_KEY,
94-
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
121+
key,
122+
default,
95123
safe_error("", e),
96124
)
97-
return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
125+
return default
126+
127+
def _get_ttl_setting(self, umo: str | None) -> float:
128+
return self._get_provider_setting(
129+
umo,
130+
MODEL_LIST_CACHE_TTL_KEY,
131+
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
132+
float,
133+
)
98134

99135
def _get_lookup_concurrency(self, umo: str | None) -> int:
100-
if not umo:
101-
return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
102-
try:
103-
cfg = self.context.get_config(umo).get("provider_settings", {})
104-
raw = cfg.get(MODEL_LOOKUP_MAX_CONCURRENCY_KEY)
105-
if raw is None:
106-
return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
107-
return int(raw)
108-
except Exception as e:
109-
logger.debug(
110-
"读取 %s 失败,回退默认值 %r: %s",
111-
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
112-
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
113-
safe_error("", e),
114-
)
115-
return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
136+
return self._get_provider_setting(
137+
umo,
138+
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
139+
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
140+
int,
141+
)
116142

117143
def _resolve_model_name(
118144
self,
@@ -135,18 +161,20 @@ def _resolve_model_name(
135161

136162
# provider-qualified suffix match:
137163
# e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`.
138-
def _match_qualified_suffix(req: str, cand: str) -> bool:
139-
return cand.endswith(f"/{req}") or cand.endswith(f":{req}")
140-
141164
for candidate in models:
142-
if _match_qualified_suffix(requested_norm, candidate.casefold()):
165+
cand_norm = candidate.casefold()
166+
if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith(
167+
f":{requested_norm}"
168+
):
143169
return candidate
144170

145171
return None
146172

147-
def _apply_model(self, prov: Provider, model_name: str) -> str:
173+
def _apply_model(
174+
self, prov: Provider, model_name: str, *, umo: str | None = None
175+
) -> str:
148176
prov.set_model(model_name)
149-
self.invalidate_provider_models_cache(prov.meta().id)
177+
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
150178
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
151179

152180
async def _get_provider_models(
@@ -159,13 +187,17 @@ async def _get_provider_models(
159187
provider_id = provider.meta().id
160188
ttl_seconds = max(float(self._get_ttl_setting(umo)), 0.0)
161189
if use_cache:
162-
cached = self._get_cached_models(provider_id, ttl_seconds=ttl_seconds)
190+
cached = self._get_cached_models(
191+
provider_id,
192+
ttl_seconds=ttl_seconds,
193+
umo=umo,
194+
)
163195
if cached is not None:
164196
return cached
165197

166198
models = list(await provider.get_models())
167199
if use_cache and ttl_seconds > 0:
168-
self._set_cached_models(provider_id, models)
200+
self._set_cached_models(provider_id, models, umo=umo)
169201
return models
170202

171203
def _log_reachability_failure(
@@ -484,7 +516,7 @@ async def _switch_model_by_name(
484516
if matched_model_name is not None:
485517
message.set_result(
486518
MessageEventResult().message(
487-
self._apply_model(prov, matched_model_name)
519+
self._apply_model(prov, matched_model_name, umo=umo)
488520
),
489521
)
490522
return
@@ -508,8 +540,7 @@ async def _switch_model_by_name(
508540
provider_type=ProviderType.CHAT_COMPLETION,
509541
umo=umo,
510542
)
511-
target_prov.set_model(matched_target_model_name)
512-
self.invalidate_provider_models_cache(target_prov.meta().id)
543+
self._apply_model(target_prov, matched_target_model_name, umo=umo)
513544
message.set_result(
514545
MessageEventResult().message(
515546
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
@@ -584,7 +615,11 @@ async def model_ls(
584615
new_model = models[idx_or_name - 1]
585616
message.set_result(
586617
MessageEventResult().message(
587-
self._apply_model(prov, new_model)
618+
self._apply_model(
619+
prov,
620+
new_model,
621+
umo=message.unified_msg_origin,
622+
)
588623
),
589624
)
590625
except Exception as e:
@@ -626,7 +661,10 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None
626661
try:
627662
new_key = keys_data[index - 1]
628663
prov.set_key(new_key)
629-
self.invalidate_provider_models_cache(prov.meta().id)
664+
self.invalidate_provider_models_cache(
665+
prov.meta().id,
666+
umo=message.unified_msg_origin,
667+
)
630668
message.set_result(MessageEventResult().message("切换 Key 成功。"))
631669
except Exception as e:
632670
message.set_result(

0 commit comments

Comments
 (0)