1212RuntimeAuthMode = Literal ["auto" , "none" , "api_key" , "jwt" ]
1313
1414_TokenKey = tuple [str , str , str ]
15- _LockKey = tuple [str , str , str , int ]
15+ _LockKey = tuple [str , str , str , asyncio . AbstractEventLoop ]
1616_DEFAULT_MAX_CACHE_ENTRIES = 256
1717_DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS = 30
1818
@@ -37,12 +37,21 @@ def is_fresh(self, *, refresh_margin_seconds: int) -> bool:
3737class RuntimeTokenCache :
3838 """Thread-safe runtime token cache keyed by server and target."""
3939
40- def __init__ (self , * , max_entries : int = _DEFAULT_MAX_CACHE_ENTRIES ) -> None :
40+ def __init__ (
41+ self ,
42+ * ,
43+ max_entries : int = _DEFAULT_MAX_CACHE_ENTRIES ,
44+ jwt_unavailable_ttl_seconds : int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS ,
45+ ) -> None :
4146 if max_entries < 1 :
4247 raise ValueError ("max_entries must be >= 1." )
48+ if jwt_unavailable_ttl_seconds < 0 :
49+ raise ValueError ("jwt_unavailable_ttl_seconds must be >= 0." )
4350 self ._max_entries = max_entries
51+ self ._jwt_unavailable_ttl_seconds = jwt_unavailable_ttl_seconds
4452 self ._tokens : dict [_TokenKey , RuntimeToken ] = {}
45- self ._jwt_unavailable = False
53+ self ._jwt_unavailable_until : datetime | None = None
54+ self ._jwt_unavailable_servers : dict [str , datetime ] = {}
4655 self ._jwt_unavailable_targets : dict [_TokenKey , datetime ] = {}
4756 self ._exchange_locks : dict [_LockKey , asyncio .Lock ] = {}
4857 self ._lock = threading .Lock ()
@@ -75,6 +84,7 @@ def set(self, token: RuntimeToken) -> None:
7584 self ._tokens .pop (oldest_key , None )
7685 self ._jwt_unavailable_targets .pop (oldest_key , None )
7786 self ._tokens [key ] = token
87+ self ._jwt_unavailable_servers .pop (token .server_url , None )
7888 self ._jwt_unavailable_targets .pop (key , None )
7989
8090 def remove (self , server_url : str , target_type : str , target_id : str ) -> None :
@@ -89,13 +99,30 @@ def mark_jwt_unavailable(
8999 target_type : str | None = None ,
90100 target_id : str | None = None ,
91101 globally : bool = False ,
92- ttl_seconds : int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS ,
102+ ttl_seconds : int | None = None ,
93103 ) -> None :
94104 """Record that JWT runtime auth should not be attempted."""
105+ ttl = self ._jwt_unavailable_ttl_seconds if ttl_seconds is None else ttl_seconds
106+ if ttl < 0 :
107+ raise ValueError ("ttl_seconds must be >= 0." )
108+ expires_at = datetime .now (UTC ) + timedelta (seconds = ttl )
95109 with self ._lock :
110+ self ._prune_expired_unavailable_markers_locked ()
96111 if globally :
97- self ._jwt_unavailable = True
98- self ._tokens .clear ()
112+ if server_url is None :
113+ self ._jwt_unavailable_until = expires_at
114+ self ._tokens .clear ()
115+ self ._jwt_unavailable_servers .clear ()
116+ self ._jwt_unavailable_targets .clear ()
117+ return
118+
119+ if (
120+ server_url not in self ._jwt_unavailable_servers
121+ and len (self ._jwt_unavailable_servers ) >= self ._max_entries
122+ ):
123+ self ._jwt_unavailable_servers .pop (next (iter (self ._jwt_unavailable_servers )))
124+ self ._jwt_unavailable_servers [server_url ] = expires_at
125+ self ._drop_server_entries_locked (server_url )
99126 return
100127 if server_url is not None and target_type is not None and target_id is not None :
101128 key = (server_url , target_type , target_id )
@@ -104,42 +131,70 @@ def mark_jwt_unavailable(
104131 and len (self ._jwt_unavailable_targets ) >= self ._max_entries
105132 ):
106133 self ._jwt_unavailable_targets .pop (next (iter (self ._jwt_unavailable_targets )))
107- self ._jwt_unavailable_targets [key ] = datetime .now (UTC ) + timedelta (
108- seconds = ttl_seconds
109- )
134+ self ._jwt_unavailable_targets [key ] = expires_at
110135 self ._tokens .pop (key , None )
111136
112137 def is_jwt_unavailable (self , server_url : str , target_type : str , target_id : str ) -> bool :
113138 """Return whether JWT exchange is known unavailable for the target."""
114139 key = (server_url , target_type , target_id )
115140 with self ._lock :
116- if self ._jwt_unavailable :
141+ self ._prune_expired_unavailable_markers_locked ()
142+ if self ._jwt_unavailable_until is not None :
117143 return True
118- expires_at = self ._jwt_unavailable_targets .get (key )
119- if expires_at is None :
120- return False
121- if expires_at > datetime .now (UTC ):
144+ if server_url in self ._jwt_unavailable_servers :
122145 return True
123- self ._jwt_unavailable_targets .pop (key , None )
124- return False
146+ expires_at = self ._jwt_unavailable_targets .get (key )
147+ return expires_at is not None
125148
126149 def clear (self ) -> None :
127150 """Clear every cached token and fallback marker."""
128151 with self ._lock :
129152 self ._tokens .clear ()
130- self ._jwt_unavailable = False
153+ self ._jwt_unavailable_until = None
154+ self ._jwt_unavailable_servers .clear ()
131155 self ._jwt_unavailable_targets .clear ()
132156
133157 def exchange_lock (self , server_url : str , target_type : str , target_id : str ) -> asyncio .Lock :
134158 """Return the async exchange lock for one server and target."""
135- key = (server_url , target_type , target_id , id ( asyncio .get_running_loop () ))
159+ key = (server_url , target_type , target_id , asyncio .get_running_loop ())
136160 with self ._lock :
137161 lock = self ._exchange_locks .get (key )
138162 if lock is None :
163+ if len (self ._exchange_locks ) >= self ._max_entries :
164+ self ._evict_idle_exchange_lock_locked ()
139165 lock = asyncio .Lock ()
140166 self ._exchange_locks [key ] = lock
167+ else :
168+ # Preserve insertion order as a simple LRU for idle-lock eviction.
169+ self ._exchange_locks .pop (key )
170+ self ._exchange_locks [key ] = lock
141171 return lock
142172
173+ def _drop_server_entries_locked (self , server_url : str ) -> None :
174+ for key in list (self ._tokens ):
175+ if key [0 ] == server_url :
176+ self ._tokens .pop (key , None )
177+ for key in list (self ._jwt_unavailable_targets ):
178+ if key [0 ] == server_url :
179+ self ._jwt_unavailable_targets .pop (key , None )
180+
181+ def _prune_expired_unavailable_markers_locked (self ) -> None :
182+ now = datetime .now (UTC )
183+ if self ._jwt_unavailable_until is not None and self ._jwt_unavailable_until <= now :
184+ self ._jwt_unavailable_until = None
185+ for server_url , expires_at in list (self ._jwt_unavailable_servers .items ()):
186+ if expires_at <= now :
187+ self ._jwt_unavailable_servers .pop (server_url , None )
188+ for key , expires_at in list (self ._jwt_unavailable_targets .items ()):
189+ if expires_at <= now :
190+ self ._jwt_unavailable_targets .pop (key , None )
191+
192+ def _evict_idle_exchange_lock_locked (self ) -> None :
193+ for key , lock in list (self ._exchange_locks .items ()):
194+ if not lock .locked ():
195+ self ._exchange_locks .pop (key , None )
196+ return
197+
143198
144199def normalize_runtime_auth_mode (raw : str | None ) -> RuntimeAuthMode :
145200 """Normalize configured SDK runtime auth mode."""
0 commit comments