Skip to content

Commit 31f40bf

Browse files
fix(sdk): bound runtime auth fallback cache
1 parent 0d59630 commit 31f40bf

4 files changed

Lines changed: 180 additions & 20 deletions

File tree

sdks/python/src/agent_control/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,12 @@ async def _exchange_runtime_token(
371371
)
372372
except httpx.RequestError:
373373
if self._runtime_auth_mode == "auto" and allow_auto_fallback:
374+
_logger.debug(
375+
"Runtime token exchange request failed; falling back to normal request auth "
376+
"for %s/%s.",
377+
target_type,
378+
target_id,
379+
)
374380
self._runtime_token_cache.mark_jwt_unavailable(
375381
server_url=self.base_url,
376382
target_type=target_type,
@@ -384,6 +390,13 @@ async def _exchange_runtime_token(
384390
and allow_auto_fallback
385391
and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES
386392
):
393+
_logger.debug(
394+
"Runtime token exchange returned HTTP %s; falling back to normal request auth "
395+
"for %s/%s.",
396+
response.status_code,
397+
target_type,
398+
target_id,
399+
)
387400
self._runtime_token_cache.mark_jwt_unavailable(
388401
server_url=self.base_url,
389402
target_type=target_type,

sdks/python/src/agent_control/runtime_auth.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
RuntimeAuthMode = Literal["auto", "none", "api_key", "jwt"]
1313

1414
_TokenKey = tuple[str, str, str]
15-
_LockKey = tuple[str, str, str, int]
15+
_LockKey = tuple[str, str, str, asyncio.AbstractEventLoop]
1616
_DEFAULT_MAX_CACHE_ENTRIES = 256
1717
_DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS = 30
1818

@@ -37,12 +37,21 @@ def is_fresh(self, *, refresh_margin_seconds: int) -> bool:
3737
class RuntimeTokenCache:
3838
"""Thread-safe runtime token cache keyed by server and target."""
3939

40-
def __init__(self, *, max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES) -> None:
40+
def __init__(
41+
self,
42+
*,
43+
max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES,
44+
jwt_unavailable_ttl_seconds: int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS,
45+
) -> None:
4146
if max_entries < 1:
4247
raise ValueError("max_entries must be >= 1.")
48+
if jwt_unavailable_ttl_seconds < 0:
49+
raise ValueError("jwt_unavailable_ttl_seconds must be >= 0.")
4350
self._max_entries = max_entries
51+
self._jwt_unavailable_ttl_seconds = jwt_unavailable_ttl_seconds
4452
self._tokens: dict[_TokenKey, RuntimeToken] = {}
45-
self._jwt_unavailable = False
53+
self._jwt_unavailable_until: datetime | None = None
54+
self._jwt_unavailable_servers: dict[str, datetime] = {}
4655
self._jwt_unavailable_targets: dict[_TokenKey, datetime] = {}
4756
self._exchange_locks: dict[_LockKey, asyncio.Lock] = {}
4857
self._lock = threading.Lock()
@@ -75,6 +84,7 @@ def set(self, token: RuntimeToken) -> None:
7584
self._tokens.pop(oldest_key, None)
7685
self._jwt_unavailable_targets.pop(oldest_key, None)
7786
self._tokens[key] = token
87+
self._jwt_unavailable_servers.pop(token.server_url, None)
7888
self._jwt_unavailable_targets.pop(key, None)
7989

8090
def remove(self, server_url: str, target_type: str, target_id: str) -> None:
@@ -89,13 +99,30 @@ def mark_jwt_unavailable(
8999
target_type: str | None = None,
90100
target_id: str | None = None,
91101
globally: bool = False,
92-
ttl_seconds: int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS,
102+
ttl_seconds: int | None = None,
93103
) -> None:
94104
"""Record that JWT runtime auth should not be attempted."""
105+
ttl = self._jwt_unavailable_ttl_seconds if ttl_seconds is None else ttl_seconds
106+
if ttl < 0:
107+
raise ValueError("ttl_seconds must be >= 0.")
108+
expires_at = datetime.now(UTC) + timedelta(seconds=ttl)
95109
with self._lock:
110+
self._prune_expired_unavailable_markers_locked()
96111
if globally:
97-
self._jwt_unavailable = True
98-
self._tokens.clear()
112+
if server_url is None:
113+
self._jwt_unavailable_until = expires_at
114+
self._tokens.clear()
115+
self._jwt_unavailable_servers.clear()
116+
self._jwt_unavailable_targets.clear()
117+
return
118+
119+
if (
120+
server_url not in self._jwt_unavailable_servers
121+
and len(self._jwt_unavailable_servers) >= self._max_entries
122+
):
123+
self._jwt_unavailable_servers.pop(next(iter(self._jwt_unavailable_servers)))
124+
self._jwt_unavailable_servers[server_url] = expires_at
125+
self._drop_server_entries_locked(server_url)
99126
return
100127
if server_url is not None and target_type is not None and target_id is not None:
101128
key = (server_url, target_type, target_id)
@@ -104,42 +131,70 @@ def mark_jwt_unavailable(
104131
and len(self._jwt_unavailable_targets) >= self._max_entries
105132
):
106133
self._jwt_unavailable_targets.pop(next(iter(self._jwt_unavailable_targets)))
107-
self._jwt_unavailable_targets[key] = datetime.now(UTC) + timedelta(
108-
seconds=ttl_seconds
109-
)
134+
self._jwt_unavailable_targets[key] = expires_at
110135
self._tokens.pop(key, None)
111136

112137
def is_jwt_unavailable(self, server_url: str, target_type: str, target_id: str) -> bool:
113138
"""Return whether JWT exchange is known unavailable for the target."""
114139
key = (server_url, target_type, target_id)
115140
with self._lock:
116-
if self._jwt_unavailable:
141+
self._prune_expired_unavailable_markers_locked()
142+
if self._jwt_unavailable_until is not None:
117143
return True
118-
expires_at = self._jwt_unavailable_targets.get(key)
119-
if expires_at is None:
120-
return False
121-
if expires_at > datetime.now(UTC):
144+
if server_url in self._jwt_unavailable_servers:
122145
return True
123-
self._jwt_unavailable_targets.pop(key, None)
124-
return False
146+
expires_at = self._jwt_unavailable_targets.get(key)
147+
return expires_at is not None
125148

126149
def clear(self) -> None:
127150
"""Clear every cached token and fallback marker."""
128151
with self._lock:
129152
self._tokens.clear()
130-
self._jwt_unavailable = False
153+
self._jwt_unavailable_until = None
154+
self._jwt_unavailable_servers.clear()
131155
self._jwt_unavailable_targets.clear()
132156

133157
def exchange_lock(self, server_url: str, target_type: str, target_id: str) -> asyncio.Lock:
134158
"""Return the async exchange lock for one server and target."""
135-
key = (server_url, target_type, target_id, id(asyncio.get_running_loop()))
159+
key = (server_url, target_type, target_id, asyncio.get_running_loop())
136160
with self._lock:
137161
lock = self._exchange_locks.get(key)
138162
if lock is None:
163+
if len(self._exchange_locks) >= self._max_entries:
164+
self._evict_idle_exchange_lock_locked()
139165
lock = asyncio.Lock()
140166
self._exchange_locks[key] = lock
167+
else:
168+
# Preserve insertion order as a simple LRU for idle-lock eviction.
169+
self._exchange_locks.pop(key)
170+
self._exchange_locks[key] = lock
141171
return lock
142172

173+
def _drop_server_entries_locked(self, server_url: str) -> None:
# Remove every cached token and every per-target JWT-unavailable marker that
# belongs to ``server_url``. Keys in both dicts are
# (server_url, target_type, target_id) tuples, so key[0] is the server.
# This helper does not take ``self._lock`` itself; the visible call site in
# mark_jwt_unavailable invokes it while already holding that lock.
# list(...) snapshots the keys so entries can be popped during the loop.
174+
for key in list(self._tokens):
175+
if key[0] == server_url:
176+
self._tokens.pop(key, None)
# Same sweep for the per-target fallback markers of this server.
177+
for key in list(self._jwt_unavailable_targets):
178+
if key[0] == server_url:
179+
self._jwt_unavailable_targets.pop(key, None)
180+
181+
def _prune_expired_unavailable_markers_locked(self) -> None:
# Discard every JWT-unavailable marker whose expiry has been reached, at all
# three scopes: the cache-wide deadline, per-server markers and per-target
# markers. Does not acquire ``self._lock``; the visible callers
# (mark_jwt_unavailable, is_jwt_unavailable) call it inside ``with self._lock``.
# A single timestamp is taken so all three scopes are judged against the
# same instant.
182+
now = datetime.now(UTC)
# The comparison is ``<=``, so a marker whose expiry equals ``now`` (e.g. one
# recorded with a TTL of 0 and checked later) counts as already expired.
183+
if self._jwt_unavailable_until is not None and self._jwt_unavailable_until <= now:
184+
self._jwt_unavailable_until = None
# list(...) snapshots the items so entries can be popped during the loop.
185+
for server_url, expires_at in list(self._jwt_unavailable_servers.items()):
186+
if expires_at <= now:
187+
self._jwt_unavailable_servers.pop(server_url, None)
188+
for key, expires_at in list(self._jwt_unavailable_targets.items()):
189+
if expires_at <= now:
190+
self._jwt_unavailable_targets.pop(key, None)
191+
192+
def _evict_idle_exchange_lock_locked(self) -> None:
    """Evict idle exchange locks, oldest first, to make room for one new lock.

    The caller must already hold ``self._lock``; this helper does not acquire it.

    At least one idle lock is evicted when one exists. If the dict has grown
    past ``max_entries`` (an earlier call found every lock held, so nothing
    could be evicted before the insert), the whole surplus is evicted so the
    configured bound is restored instead of staying exceeded indefinitely.
    Locks that are currently held are never evicted.
    """
    # Number of entries that must go so that one insert afterwards leaves the
    # dict at or below max_entries; never less than the single eviction the
    # caller asked for.
    surplus = len(self._exchange_locks) - self._max_entries + 1
    remaining = max(1, surplus)

    # Dict order is insertion order and exchange_lock() re-inserts on every
    # hit, so iterating from the front visits least-recently-used locks first.
    # NOTE(review): locked() is False between a release and the next waiter
    # waking up; evicting in that window could hand out a second lock for the
    # same key - confirm whether callers can have waiters queued here.
    for key, lock in list(self._exchange_locks.items()):
        if lock.locked():
            continue
        self._exchange_locks.pop(key, None)
        remaining -= 1
        if remaining == 0:
            return
197+
143198

144199
def normalize_runtime_auth_mode(raw: str | None) -> RuntimeAuthMode:
145200
"""Normalize configured SDK runtime auth mode."""

sdks/python/tests/test_client.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,57 @@ def handler(request: httpx.Request) -> httpx.Response:
369369
assert evaluation_api_key_headers == ["test-key", "test-key"]
370370

371371

372+
# Checks that in "auto" mode a 404 from the token exchange only disables JWT
# auth until the fallback marker expires: the first evaluation is sent with
# the API key, the second one with the exchanged bearer token.
@pytest.mark.asyncio
373+
async def test_runtime_evaluation_auto_404_fallback_recovers_after_ttl() -> None:
374+
exchange_calls = 0
375+
evaluation_authorization_headers: list[str | None] = []
376+
evaluation_api_key_headers: list[str | None] = []
377+
expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat()
# TTL of 0 makes the JWT-unavailable marker expire immediately, so the second
# request attempts the exchange again.
378+
cache = RuntimeTokenCache(jwt_unavailable_ttl_seconds=0)
379+
380+
def handler(request: httpx.Request) -> httpx.Response:
381+
nonlocal exchange_calls
# Token-exchange endpoint: 404 on the first call, a valid token afterwards.
382+
if request.url.path == "/api/v1/auth/runtime-token-exchange":
383+
exchange_calls += 1
384+
if exchange_calls == 1:
385+
return httpx.Response(404, json={"detail": "not found"})
386+
return httpx.Response(
387+
200,
388+
json={
389+
"token": "runtime-token",
390+
"expires_at": expires_at,
391+
"target_type": "log_stream",
392+
"target_id": "ls-1",
393+
"scopes": ["runtime.use"],
394+
},
395+
)
396+
# Any other request is treated as the evaluation call; record which
# credentials it carried.
397+
evaluation_authorization_headers.append(request.headers.get("Authorization"))
398+
evaluation_api_key_headers.append(request.headers.get("X-API-Key"))
399+
return httpx.Response(200, json={"is_safe": True, "confidence": 1.0})
400+
401+
transport = httpx.MockTransport(handler)
402+
403+
async with AgentControlClient(
404+
base_url="https://agent-control.test",
405+
api_key="test-key",
406+
runtime_auth_mode="auto",
407+
runtime_token_cache=cache,
408+
transport=transport,
409+
) as client:
# Two identical evaluations for the same target.
410+
for _ in range(2):
411+
response = await client.post_runtime_evaluation(
412+
json={"target_type": "log_stream", "target_id": "ls-1"},
413+
target_type="log_stream",
414+
target_id="ls-1",
415+
)
416+
assert response.status_code == 200
417+
# One exchange attempt per evaluation; API key used first, bearer token second.
418+
assert exchange_calls == 2
419+
assert evaluation_authorization_headers == [None, "Bearer runtime-token"]
420+
assert evaluation_api_key_headers == ["test-key", None]
421+
422+
372423
@pytest.mark.asyncio
373424
async def test_runtime_evaluation_auto_503_fallback_is_target_scoped() -> None:
374425
exchange_targets: list[str] = []

sdks/python/tests/test_runtime_auth.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,41 @@ def test_runtime_token_cache_target_unavailable_marker_expires() -> None:
9898
assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1")
9999

100100

101-
def test_runtime_token_cache_global_unavailable_clears_cache() -> None:
101+
def test_runtime_token_cache_server_unavailable_clears_server_cache() -> None:
102102
cache = RuntimeTokenCache()
103103
cache.set(_runtime_token())
104+
token_2 = _runtime_token(
105+
token="token-2",
106+
server_url="https://server-b.test",
107+
target_id="ls-2",
108+
)
109+
cache.set(token_2)
104110

105-
cache.mark_jwt_unavailable(globally=True)
111+
cache.mark_jwt_unavailable(server_url="https://server-a.test", globally=True)
106112

107113
assert cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1")
114+
assert not cache.is_jwt_unavailable("https://server-b.test", "log_stream", "ls-2")
108115
assert (
109116
cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None
110117
)
118+
assert (
119+
cache.get("https://server-b.test", "log_stream", "ls-2", refresh_margin_seconds=0)
120+
== token_2
121+
)
111122

112123
cache.clear()
113124

114125
assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1")
115126

116127

128+
# A server-scoped JWT-unavailable marker recorded with a TTL of 0 must already
# be expired when it is queried.
def test_runtime_token_cache_server_unavailable_marker_expires() -> None:
129+
cache = RuntimeTokenCache(jwt_unavailable_ttl_seconds=0)
130+
# globally=True together with a server_url marks that one server.
131+
cache.mark_jwt_unavailable(server_url="https://server-a.test", globally=True)
132+
133+
assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1")
134+
135+
117136
def test_runtime_token_cache_remove_drops_one_token() -> None:
118137
cache = RuntimeTokenCache()
119138
cache.set(_runtime_token(target_id="ls-1"))
@@ -159,6 +178,16 @@ async def test_runtime_token_cache_token_eviction_preserves_exchange_lock() -> N
159178
assert cache.exchange_lock("https://server-a.test", "log_stream", "ls-1") is lock
160179

161180

181+
# With room for a single entry, requesting a lock for a second target evicts
# the idle lock of the first, so asking for the first target again yields a
# different lock object.
@pytest.mark.asyncio
182+
async def test_runtime_token_cache_evicts_idle_exchange_locks() -> None:
183+
cache = RuntimeTokenCache(max_entries=1)
184+
first = cache.exchange_lock("https://server-a.test", "log_stream", "ls-1")
185+
# ``first`` is never acquired, so it is idle and eligible for eviction.
186+
cache.exchange_lock("https://server-a.test", "log_stream", "ls-2")
187+
188+
assert cache.exchange_lock("https://server-a.test", "log_stream", "ls-1") is not first
189+
190+
162191
def test_runtime_token_cache_exchange_locks_are_loop_scoped(
163192
monkeypatch: pytest.MonkeyPatch,
164193
) -> None:
@@ -180,6 +209,18 @@ def test_runtime_token_cache_rejects_empty_capacity() -> None:
180209
RuntimeTokenCache(max_entries=0)
181210

182211

212+
# The constructor must reject a negative default fallback TTL with a
# ValueError whose message names the parameter.
def test_runtime_token_cache_rejects_negative_unavailable_ttl() -> None:
213+
with pytest.raises(ValueError, match="jwt_unavailable_ttl_seconds"):
214+
RuntimeTokenCache(jwt_unavailable_ttl_seconds=-1)
215+
216+
217+
# A negative per-call ttl_seconds must be rejected by mark_jwt_unavailable
# with a ValueError, even when no server or target is given.
def test_runtime_token_cache_rejects_negative_marker_ttl() -> None:
218+
cache = RuntimeTokenCache()
219+
220+
with pytest.raises(ValueError, match="ttl_seconds"):
221+
cache.mark_jwt_unavailable(ttl_seconds=-1)
222+
223+
183224
@pytest.mark.parametrize(
184225
("raw", "expected"),
185226
[

0 commit comments

Comments
 (0)