Skip to content

Commit 6655af2

Browse files
feat(sdk-python): add runtime token auth
Exchange target-bound runtime tokens for evaluation requests when configured, cache them per target, and retry once after a 401. Keep no-auth and API-key runtime flows on the existing request-auth path when token exchange is unavailable or disabled.
1 parent 097b42d commit 6655af2

8 files changed

Lines changed: 1260 additions & 21 deletions

File tree

sdks/python/src/agent_control/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ async def handle(message: str):
561561
state.current_agent = next_agent
562562
state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000'
563563
state.api_key = api_key
564+
state.runtime_token_cache.clear()
564565
state.target_type = target_type
565566
state.target_id = target_id
566567

@@ -596,7 +597,8 @@ async def register() -> list[dict[str, Any]] | None:
596597
assert state.current_agent is not None
597598

598599
async with AgentControlClient(
599-
base_url=state.server_url, api_key=state.api_key
600+
base_url=state.server_url,
601+
api_key=state.api_key,
600602
) as client:
601603
# Check server health first
602604
try:
@@ -714,6 +716,7 @@ def _reset_state() -> None:
714716
state.server_controls = None
715717
state.server_url = None
716718
state.api_key = None
719+
state.runtime_token_cache.clear()
717720
state.target_type = None
718721
state.target_id = None
719722

sdks/python/src/agent_control/_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from typing import TYPE_CHECKING, Any
1010

11+
from .runtime_auth import RuntimeTokenCache
12+
1113
if TYPE_CHECKING:
1214
from agent_control_models import Agent
1315

@@ -24,6 +26,7 @@ def __init__(self) -> None:
2426
self.server_controls: list[dict[str, Any]] | None = None
2527
self.server_url: str | None = None
2628
self.api_key: str | None = None
29+
self.runtime_token_cache = RuntimeTokenCache()
2730
# Optional target context fixed at init() time; both fields are set
2831
# together or both remain None.
2932
self.target_type: str | None = None

sdks/python/src/agent_control/client.py

Lines changed: 207 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,43 @@
22

33
import logging
44
import os
5+
from collections.abc import Generator
56
from types import TracebackType
7+
from typing import Any, cast
68

79
import httpx
810

911
from . import __version__ as sdk_version
12+
from .runtime_auth import (
13+
RuntimeAuthMode,
14+
RuntimeTokenCache,
15+
normalize_runtime_auth_mode,
16+
parse_runtime_token_exchange_response,
17+
)
1018

1119
_logger = logging.getLogger(__name__)
1220

21+
_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
22+
_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
23+
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
24+
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
25+
26+
27+
class _AgentControlAuth(httpx.Auth):
28+
"""Attach local API-key credentials unless a request already has Bearer auth."""
29+
30+
def __init__(self, api_key: str | None) -> None:
31+
self._api_key = api_key
32+
33+
def auth_flow(
34+
self,
35+
request: httpx.Request,
36+
) -> Generator[httpx.Request, httpx.Response, None]:
37+
if self._api_key and "Authorization" not in request.headers:
38+
if "X-API-Key" not in request.headers:
39+
request.headers["X-API-Key"] = self._api_key
40+
yield request
41+
1342

1443
class AgentControlClient:
1544
"""
@@ -45,6 +74,10 @@ def __init__(
4574
base_url: str | None = None,
4675
timeout: float = 30.0,
4776
api_key: str | None = None,
77+
runtime_auth_mode: RuntimeAuthMode | str | None = None,
78+
runtime_token_cache: RuntimeTokenCache | None = None,
79+
runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS),
80+
transport: httpx.AsyncBaseTransport | None = None,
4881
):
4982
"""
5083
Initialize the client.
@@ -55,13 +88,29 @@ def __init__(
5588
timeout: Request timeout in seconds
5689
api_key: API key for authentication. If not provided, will attempt
5790
to read from AGENT_CONTROL_API_KEY environment variable.
91+
runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto``
92+
attempts target-bound JWT exchange and falls back to normal
93+
request auth when the exchange endpoint is unavailable. ``jwt``
94+
requires a successful exchange. ``api_key`` and ``none`` keep
95+
evaluation requests on the normal request-auth path.
96+
runtime_token_cache: Optional cache shared across client instances.
97+
runtime_token_refresh_margin_seconds: Refresh cached runtime tokens
98+
before this many seconds of validity remain.
99+
transport: Optional httpx transport, primarily for tests.
58100
"""
59101
resolved_base_url = base_url or os.environ.get(
60102
self.BASE_URL_ENV_VAR, "http://localhost:8000"
61103
)
62104
self.base_url = resolved_base_url.rstrip("/")
63105
self.timeout = timeout
64106
self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
107+
configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR)
108+
self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode)
109+
if runtime_token_refresh_margin_seconds < 0:
110+
raise ValueError("runtime_token_refresh_margin_seconds must be >= 0.")
111+
self._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds
112+
self._runtime_token_cache = runtime_token_cache or RuntimeTokenCache()
113+
self._transport = transport
65114
self._client: httpx.AsyncClient | None = None
66115
self._server_version_warning_emitted = False
67116

@@ -70,15 +119,17 @@ def api_key(self) -> str | None:
70119
"""Get the configured API key (read-only)."""
71120
return self._api_key
72121

122+
@property
123+
def runtime_auth_mode(self) -> RuntimeAuthMode:
124+
"""Get the configured runtime auth mode (read-only)."""
125+
return self._runtime_auth_mode
126+
73127
def _get_headers(self) -> dict[str, str]:
74-
"""Build request headers including authentication."""
75-
headers: dict[str, str] = {
128+
"""Build base SDK metadata headers."""
129+
return {
76130
"X-Agent-Control-SDK": "python",
77131
"X-Agent-Control-SDK-Version": sdk_version,
78132
}
79-
if self._api_key:
80-
headers["X-API-Key"] = self._api_key
81-
return headers
82133

83134
async def _check_server_version(self, response: httpx.Response) -> None:
84135
"""Warn once when the server major version differs from the SDK major."""
@@ -108,6 +159,8 @@ async def __aenter__(self) -> "AgentControlClient":
108159
base_url=self.base_url,
109160
timeout=self.timeout,
110161
headers=self._get_headers(),
162+
auth=_AgentControlAuth(self._api_key),
163+
transport=self._transport,
111164
event_hooks={"response": [self._check_server_version]},
112165
)
113166
return self
@@ -137,6 +190,7 @@ async def health_check(self) -> dict[str, str]:
137190
response = await self._client.get("/health")
138191
response.raise_for_status()
139192
from typing import cast
193+
140194
return cast(dict[str, str], response.json())
141195

142196
@property
@@ -145,3 +199,151 @@ def http_client(self) -> httpx.AsyncClient:
145199
if self._client is None:
146200
raise RuntimeError("Client not initialized. Use 'async with' context manager.")
147201
return self._client
202+
203+
async def post_runtime_evaluation(
204+
self,
205+
*,
206+
json: dict[str, Any],
207+
headers: dict[str, str] | None = None,
208+
target_type: str | None = None,
209+
target_id: str | None = None,
210+
) -> httpx.Response:
211+
"""POST an evaluation request with runtime auth when configured."""
212+
runtime_authorization = await self._runtime_authorization(
213+
target_type=target_type,
214+
target_id=target_id,
215+
)
216+
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
217+
response = await self.http_client.post(
218+
"/api/v1/evaluation",
219+
json=json,
220+
headers=request_headers,
221+
)
222+
223+
if response.status_code == 401 and runtime_authorization is not None:
224+
await response.aread()
225+
runtime_authorization = await self._runtime_authorization(
226+
target_type=target_type,
227+
target_id=target_id,
228+
force_refresh=True,
229+
allow_auto_fallback=False,
230+
)
231+
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
232+
response = await self.http_client.post(
233+
"/api/v1/evaluation",
234+
json=json,
235+
headers=request_headers,
236+
)
237+
238+
return response
239+
240+
def _merge_runtime_headers(
241+
self,
242+
headers: dict[str, str] | None,
243+
runtime_authorization: str | None,
244+
) -> dict[str, str] | None:
245+
"""Merge caller headers with an optional Bearer token."""
246+
if headers is None and runtime_authorization is None:
247+
return None
248+
249+
merged = dict(headers or {})
250+
if runtime_authorization is not None:
251+
merged["Authorization"] = runtime_authorization
252+
return merged
253+
254+
async def _runtime_authorization(
255+
self,
256+
*,
257+
target_type: str | None,
258+
target_id: str | None,
259+
force_refresh: bool = False,
260+
allow_auto_fallback: bool = True,
261+
) -> str | None:
262+
"""Return an Authorization header value for runtime evaluation."""
263+
if self._runtime_auth_mode in {"none", "api_key"}:
264+
return None
265+
266+
if target_type is None or target_id is None:
267+
if self._runtime_auth_mode == "jwt":
268+
raise RuntimeError(
269+
"runtime_auth_mode='jwt' requires target_type and target_id "
270+
"for evaluation requests."
271+
)
272+
return None
273+
274+
if self._runtime_auth_mode == "auto" and self._runtime_token_cache.is_jwt_unavailable(
275+
self.base_url, target_type, target_id
276+
):
277+
return None
278+
279+
if not force_refresh:
280+
cached = self._runtime_token_cache.get(
281+
self.base_url,
282+
target_type,
283+
target_id,
284+
refresh_margin_seconds=self._runtime_token_refresh_margin_seconds,
285+
)
286+
if cached is not None:
287+
return f"Bearer {cached.token}"
288+
289+
exchange_lock = self._runtime_token_cache.exchange_lock(
290+
self.base_url,
291+
target_type,
292+
target_id,
293+
)
294+
async with exchange_lock:
295+
if not force_refresh:
296+
cached = self._runtime_token_cache.get(
297+
self.base_url,
298+
target_type,
299+
target_id,
300+
refresh_margin_seconds=self._runtime_token_refresh_margin_seconds,
301+
)
302+
if cached is not None:
303+
return f"Bearer {cached.token}"
304+
305+
token = await self._exchange_runtime_token(
306+
target_type=target_type,
307+
target_id=target_id,
308+
allow_auto_fallback=allow_auto_fallback,
309+
)
310+
if token is None:
311+
return None
312+
return f"Bearer {token}"
313+
314+
async def _exchange_runtime_token(
315+
self,
316+
*,
317+
target_type: str,
318+
target_id: str,
319+
allow_auto_fallback: bool = True,
320+
) -> str | None:
321+
"""Exchange the configured credential for a target-bound runtime token."""
322+
response = await self.http_client.post(
323+
"/api/v1/auth/runtime-token-exchange",
324+
json={"target_type": target_type, "target_id": target_id},
325+
)
326+
327+
if (
328+
self._runtime_auth_mode == "auto"
329+
and allow_auto_fallback
330+
and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES
331+
):
332+
self._runtime_token_cache.mark_jwt_unavailable(
333+
server_url=self.base_url,
334+
target_type=target_type,
335+
target_id=target_id,
336+
globally=response.status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES,
337+
)
338+
return None
339+
340+
response.raise_for_status()
341+
payload = response.json()
342+
if not isinstance(payload, dict):
343+
raise RuntimeError("Runtime token exchange response was not an object.")
344+
token = parse_runtime_token_exchange_response(
345+
cast(dict[str, object], payload),
346+
server_url=self.base_url,
347+
)
348+
self._runtime_token_cache.set(token)
349+
return token.token

0 commit comments

Comments
 (0)