Skip to content

Commit 26770a3

Browse files
fix(sdk): harden runtime token auth
1 parent 95efaf2 commit 26770a3

10 files changed

Lines changed: 537 additions & 32 deletions

sdks/python/src/agent_control/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,11 @@ async def handle(message: str):
542542
raise ValueError(
543543
"target_type and target_id must be supplied together."
544544
)
545+
resolved_api_key_header = (
546+
api_key_header
547+
or os.getenv(AgentControlClient.API_KEY_HEADER_ENV_VAR)
548+
or AgentControlClient.DEFAULT_API_KEY_HEADER
549+
)
545550

546551
# Re-init behavior: always stop the existing refresh loop before mutating
547552
# shared agent/session globals.
@@ -569,7 +574,7 @@ async def handle(message: str):
569574
state.current_agent = next_agent
570575
state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000'
571576
state.api_key = api_key
572-
state.api_key_header = api_key_header
577+
state.api_key_header = resolved_api_key_header
573578
state.runtime_token_cache.clear()
574579
state.target_type = target_type
575580
state.target_id = target_id

sdks/python/src/agent_control/client.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
2222
_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
23-
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
24-
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
23+
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 500, 502, 503, 504}
24+
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404}
2525

2626

2727
class _AgentControlAuth(httpx.Auth):
@@ -248,8 +248,10 @@ async def post_runtime_evaluation(
248248
headers=request_headers,
249249
)
250250

251-
if response.status_code == 401 and runtime_authorization is not None:
251+
if _should_refresh_runtime_token(response) and runtime_authorization is not None:
252252
await response.aread()
253+
if target_type is not None and target_id is not None:
254+
self._runtime_token_cache.remove(self.base_url, target_type, target_id)
253255
runtime_authorization = await self._runtime_authorization(
254256
target_type=target_type,
255257
target_id=target_id,
@@ -262,6 +264,13 @@ async def post_runtime_evaluation(
262264
json=json,
263265
headers=request_headers,
264266
)
267+
if (
268+
_should_refresh_runtime_token(response)
269+
and target_type is not None
270+
and target_id is not None
271+
):
272+
await response.aread()
273+
self._runtime_token_cache.remove(self.base_url, target_type, target_id)
265274

266275
return response
267276

@@ -320,6 +329,14 @@ async def _runtime_authorization(
320329
target_id,
321330
)
322331
async with exchange_lock:
332+
if (
333+
self._runtime_auth_mode == "auto"
334+
and not force_refresh
335+
and self._runtime_token_cache.is_jwt_unavailable(
336+
self.base_url, target_type, target_id
337+
)
338+
):
339+
return None
323340
if not force_refresh:
324341
cached = self._runtime_token_cache.get(
325342
self.base_url,
@@ -347,10 +364,20 @@ async def _exchange_runtime_token(
347364
allow_auto_fallback: bool = True,
348365
) -> str | None:
349366
"""Exchange the configured credential for a target-bound runtime token."""
350-
response = await self.http_client.post(
351-
"/api/v1/auth/runtime-token-exchange",
352-
json={"target_type": target_type, "target_id": target_id},
353-
)
367+
try:
368+
response = await self.http_client.post(
369+
"/api/v1/auth/runtime-token-exchange",
370+
json={"target_type": target_type, "target_id": target_id},
371+
)
372+
except httpx.RequestError:
373+
if self._runtime_auth_mode == "auto" and allow_auto_fallback:
374+
self._runtime_token_cache.mark_jwt_unavailable(
375+
server_url=self.base_url,
376+
target_type=target_type,
377+
target_id=target_id,
378+
)
379+
return None
380+
raise
354381

355382
if (
356383
self._runtime_auth_mode == "auto"
@@ -373,5 +400,18 @@ async def _exchange_runtime_token(
373400
cast(dict[str, object], payload),
374401
server_url=self.base_url,
375402
)
403+
if token.target_type != target_type or token.target_id != target_id:
404+
raise RuntimeError(
405+
"Runtime token exchange response target did not match the requested target."
406+
)
376407
self._runtime_token_cache.set(token)
377408
return token.token
409+
410+
411+
def _should_refresh_runtime_token(response: httpx.Response) -> bool:
412+
if response.status_code == 401:
413+
return True
414+
if response.status_code != 403:
415+
return False
416+
authenticate = response.headers.get("WWW-Authenticate", "")
417+
return "invalid_token" in authenticate.lower()

sdks/python/src/agent_control/control_decorators.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ async def chat(message: str) -> str:
4040

4141
from agent_control import AgentControlClient
4242
from agent_control._state import state
43-
from agent_control.evaluation import _resolve_session_target, check_evaluation_with_local
43+
from agent_control.evaluation import (
44+
_post_evaluation_request,
45+
_resolve_session_target,
46+
check_evaluation_with_local,
47+
)
4448
from agent_control.observability import (
4549
get_logger,
4650
log_control_evaluation,
@@ -388,10 +392,12 @@ async def _evaluate(
388392
payload["target_type"] = target_type
389393
payload["target_id"] = target_id
390394

391-
response = await client.http_client.post(
392-
"/api/v1/evaluation",
393-
json=payload,
395+
response = await _post_evaluation_request(
396+
client,
397+
request_payload=payload,
394398
headers=headers,
399+
target_type=target_type,
400+
target_id=target_id,
395401
)
396402
response.raise_for_status()
397403
result_dict: dict[str, Any] = response.json()

sdks/python/src/agent_control/evaluation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ async def _post_evaluation_request(
229229
) -> httpx.Response:
230230
"""Send an evaluation request, using runtime auth when the client supports it."""
231231
runtime_post = None
232-
if (target_type is not None and target_id is not None) or client.runtime_auth_mode == "jwt":
232+
if (target_type is not None and target_id is not None) or getattr(
233+
client, "runtime_auth_mode", "auto"
234+
) == "jwt":
233235
runtime_post = _runtime_post_evaluation(client)
234236
if runtime_post is not None:
235237
return await runtime_post(

sdks/python/src/agent_control/runtime_auth.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,23 @@
55
import asyncio
66
import threading
77
from collections.abc import Mapping, Sequence
8-
from dataclasses import dataclass
8+
from dataclasses import dataclass, field
99
from datetime import UTC, datetime, timedelta
1010
from typing import Literal
1111

1212
RuntimeAuthMode = Literal["auto", "none", "api_key", "jwt"]
1313

1414
_TokenKey = tuple[str, str, str]
15+
_LockKey = tuple[str, str, str, int]
1516
_DEFAULT_MAX_CACHE_ENTRIES = 256
17+
_DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS = 30
1618

1719

1820
@dataclass(frozen=True)
1921
class RuntimeToken:
2022
"""Short-lived runtime token bound to one target."""
2123

22-
token: str
24+
token: str = field(repr=False)
2325
expires_at: datetime
2426
server_url: str
2527
target_type: str
@@ -41,8 +43,8 @@ def __init__(self, *, max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES) -> None:
4143
self._max_entries = max_entries
4244
self._tokens: dict[_TokenKey, RuntimeToken] = {}
4345
self._jwt_unavailable = False
44-
self._jwt_unavailable_targets: set[_TokenKey] = set()
45-
self._exchange_locks: dict[_TokenKey, asyncio.Lock] = {}
46+
self._jwt_unavailable_targets: dict[_TokenKey, datetime] = {}
47+
self._exchange_locks: dict[_LockKey, asyncio.Lock] = {}
4648
self._lock = threading.Lock()
4749

4850
def get(
@@ -71,10 +73,9 @@ def set(self, token: RuntimeToken) -> None:
7173
if key not in self._tokens and len(self._tokens) >= self._max_entries:
7274
oldest_key = next(iter(self._tokens))
7375
self._tokens.pop(oldest_key, None)
74-
self._jwt_unavailable_targets.discard(oldest_key)
75-
self._exchange_locks.pop(oldest_key, None)
76+
self._jwt_unavailable_targets.pop(oldest_key, None)
7677
self._tokens[key] = token
77-
self._jwt_unavailable_targets.discard(key)
78+
self._jwt_unavailable_targets.pop(key, None)
7879

7980
def remove(self, server_url: str, target_type: str, target_id: str) -> None:
8081
"""Drop the cached token for one target."""
@@ -88,6 +89,7 @@ def mark_jwt_unavailable(
8889
target_type: str | None = None,
8990
target_id: str | None = None,
9091
globally: bool = False,
92+
ttl_seconds: int = _DEFAULT_JWT_UNAVAILABLE_TTL_SECONDS,
9193
) -> None:
9294
"""Record that JWT runtime auth should not be attempted."""
9395
with self._lock:
@@ -101,28 +103,36 @@ def mark_jwt_unavailable(
101103
key not in self._jwt_unavailable_targets
102104
and len(self._jwt_unavailable_targets) >= self._max_entries
103105
):
104-
evicted_key = self._jwt_unavailable_targets.pop()
105-
self._exchange_locks.pop(evicted_key, None)
106-
self._jwt_unavailable_targets.add(key)
106+
self._jwt_unavailable_targets.pop(next(iter(self._jwt_unavailable_targets)))
107+
self._jwt_unavailable_targets[key] = datetime.now(UTC) + timedelta(
108+
seconds=ttl_seconds
109+
)
107110
self._tokens.pop(key, None)
108111

109112
def is_jwt_unavailable(self, server_url: str, target_type: str, target_id: str) -> bool:
110113
"""Return whether JWT exchange is known unavailable for the target."""
111114
key = (server_url, target_type, target_id)
112115
with self._lock:
113-
return self._jwt_unavailable or key in self._jwt_unavailable_targets
116+
if self._jwt_unavailable:
117+
return True
118+
expires_at = self._jwt_unavailable_targets.get(key)
119+
if expires_at is None:
120+
return False
121+
if expires_at > datetime.now(UTC):
122+
return True
123+
self._jwt_unavailable_targets.pop(key, None)
124+
return False
114125

115126
def clear(self) -> None:
116127
"""Clear every cached token and fallback marker."""
117128
with self._lock:
118129
self._tokens.clear()
119130
self._jwt_unavailable = False
120131
self._jwt_unavailable_targets.clear()
121-
self._exchange_locks.clear()
122132

123133
def exchange_lock(self, server_url: str, target_type: str, target_id: str) -> asyncio.Lock:
124134
"""Return the async exchange lock for one server and target."""
125-
key = (server_url, target_type, target_id)
135+
key = (server_url, target_type, target_id, id(asyncio.get_running_loop()))
126136
with self._lock:
127137
lock = self._exchange_locks.get(key)
128138
if lock is None:

0 commit comments

Comments
 (0)