55import asyncio
66import threading
77from collections .abc import Mapping , Sequence
8- from dataclasses import dataclass
8+ from dataclasses import dataclass , field
99from datetime import UTC , datetime , timedelta
1010from typing import Literal
1111
1212RuntimeAuthMode = Literal ["auto" , "none" , "api_key" , "jwt" ]
1313
1414_TokenKey = tuple [str , str , str ]
15+ _LockKey = tuple [str , str , str , int ]
1516_DEFAULT_MAX_CACHE_ENTRIES = 256
17+ _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS = 30
1618
1719
1820@dataclass (frozen = True )
1921class RuntimeToken :
2022 """Short-lived runtime token bound to one target."""
2123
22- token : str
24+ token : str = field ( repr = False )
2325 expires_at : datetime
2426 server_url : str
2527 target_type : str
@@ -41,8 +43,8 @@ def __init__(self, *, max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES) -> None:
4143 self ._max_entries = max_entries
4244 self ._tokens : dict [_TokenKey , RuntimeToken ] = {}
4345 self ._jwt_unavailable = False
44- self ._jwt_unavailable_targets : set [_TokenKey ] = set ()
45- self ._exchange_locks : dict [_TokenKey , asyncio .Lock ] = {}
46+ self ._jwt_unavailable_targets : dict [_TokenKey , datetime ] = {}
47+ self ._exchange_locks : dict [_LockKey , asyncio .Lock ] = {}
4648 self ._lock = threading .Lock ()
4749
4850 def get (
@@ -71,10 +73,9 @@ def set(self, token: RuntimeToken) -> None:
7173 if key not in self ._tokens and len (self ._tokens ) >= self ._max_entries :
7274 oldest_key = next (iter (self ._tokens ))
7375 self ._tokens .pop (oldest_key , None )
74- self ._jwt_unavailable_targets .discard (oldest_key )
75- self ._exchange_locks .pop (oldest_key , None )
76+ self ._jwt_unavailable_targets .pop (oldest_key , None )
7677 self ._tokens [key ] = token
77- self ._jwt_unavailable_targets .discard (key )
78+ self ._jwt_unavailable_targets .pop (key , None )
7879
7980 def remove (self , server_url : str , target_type : str , target_id : str ) -> None :
8081 """Drop the cached token for one target."""
@@ -88,6 +89,7 @@ def mark_jwt_unavailable(
8889 target_type : str | None = None ,
8990 target_id : str | None = None ,
9091 globally : bool = False ,
92+ ttl_seconds : int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS ,
9193 ) -> None :
9294 """Record that JWT runtime auth should not be attempted."""
9395 with self ._lock :
@@ -101,28 +103,36 @@ def mark_jwt_unavailable(
101103 key not in self ._jwt_unavailable_targets
102104 and len (self ._jwt_unavailable_targets ) >= self ._max_entries
103105 ):
104- evicted_key = self ._jwt_unavailable_targets .pop ()
105- self ._exchange_locks .pop (evicted_key , None )
106- self ._jwt_unavailable_targets .add (key )
106+ self ._jwt_unavailable_targets .pop (next (iter (self ._jwt_unavailable_targets )))
107+ self ._jwt_unavailable_targets [key ] = datetime .now (UTC ) + timedelta (
108+ seconds = ttl_seconds
109+ )
107110 self ._tokens .pop (key , None )
108111
109112 def is_jwt_unavailable (self , server_url : str , target_type : str , target_id : str ) -> bool :
110113 """Return whether JWT exchange is known unavailable for the target."""
111114 key = (server_url , target_type , target_id )
112115 with self ._lock :
113- return self ._jwt_unavailable or key in self ._jwt_unavailable_targets
116+ if self ._jwt_unavailable :
117+ return True
118+ expires_at = self ._jwt_unavailable_targets .get (key )
119+ if expires_at is None :
120+ return False
121+ if expires_at > datetime .now (UTC ):
122+ return True
123+ self ._jwt_unavailable_targets .pop (key , None )
124+ return False
114125
115126 def clear (self ) -> None :
116127 """Clear every cached token and fallback marker."""
117128 with self ._lock :
118129 self ._tokens .clear ()
119130 self ._jwt_unavailable = False
120131 self ._jwt_unavailable_targets .clear ()
121- self ._exchange_locks .clear ()
122132
123133 def exchange_lock (self , server_url : str , target_type : str , target_id : str ) -> asyncio .Lock :
124134 """Return the async exchange lock for one server and target."""
125- key = (server_url , target_type , target_id )
135+ key = (server_url , target_type , target_id , id ( asyncio . get_running_loop ()) )
126136 with self ._lock :
127137 lock = self ._exchange_locks .get (key )
128138 if lock is None :
0 commit comments