Skip to content

Commit 4fffcf2

Browse files
feat(sdk-python): add runtime token auth
Exchange target-bound runtime tokens for evaluation requests when configured, cache them per target, and retry once after a 401. Keep no-auth and API-key runtime flows on the existing request-auth path when token exchange is unavailable or disabled.
1 parent 1e7f59a commit 4fffcf2

7 files changed

Lines changed: 861 additions & 21 deletions

File tree

sdks/python/src/agent_control/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ async def handle(message: str):
561561
state.current_agent = next_agent
562562
state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000'
563563
state.api_key = api_key
564+
state.runtime_token_cache.clear()
564565
state.target_type = target_type
565566
state.target_id = target_id
566567

@@ -596,7 +597,8 @@ async def register() -> list[dict[str, Any]] | None:
596597
assert state.current_agent is not None
597598

598599
async with AgentControlClient(
599-
base_url=state.server_url, api_key=state.api_key
600+
base_url=state.server_url,
601+
api_key=state.api_key,
600602
) as client:
601603
# Check server health first
602604
try:
@@ -714,6 +716,7 @@ def _reset_state() -> None:
714716
state.server_controls = None
715717
state.server_url = None
716718
state.api_key = None
719+
state.runtime_token_cache.clear()
717720
state.target_type = None
718721
state.target_id = None
719722

sdks/python/src/agent_control/_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from typing import TYPE_CHECKING, Any
1010

11+
from .runtime_auth import RuntimeTokenCache
12+
1113
if TYPE_CHECKING:
1214
from agent_control_models import Agent
1315

@@ -24,6 +26,7 @@ def __init__(self) -> None:
2426
self.server_controls: list[dict[str, Any]] | None = None
2527
self.server_url: str | None = None
2628
self.api_key: str | None = None
29+
self.runtime_token_cache = RuntimeTokenCache()
2730
# Optional target context fixed at init() time; both fields are set
2831
# together or both remain None.
2932
self.target_type: str | None = None

sdks/python/src/agent_control/client.py

Lines changed: 186 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,43 @@
22

33
import logging
44
import os
5+
from collections.abc import Generator
56
from types import TracebackType
7+
from typing import Any, cast
68

79
import httpx
810

911
from . import __version__ as sdk_version
12+
from .runtime_auth import (
13+
RuntimeAuthMode,
14+
RuntimeTokenCache,
15+
normalize_runtime_auth_mode,
16+
parse_runtime_token_exchange_response,
17+
)
1018

1119
_logger = logging.getLogger(__name__)
1220

21+
_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
22+
_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
23+
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {400, 401, 403, 404, 503}
24+
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
25+
26+
27+
class _AgentControlAuth(httpx.Auth):
28+
"""Attach local API-key credentials unless a request already has Bearer auth."""
29+
30+
def __init__(self, api_key: str | None) -> None:
31+
self._api_key = api_key
32+
33+
def auth_flow(
34+
self,
35+
request: httpx.Request,
36+
) -> Generator[httpx.Request, httpx.Response, None]:
37+
if self._api_key and "Authorization" not in request.headers:
38+
if "X-API-Key" not in request.headers:
39+
request.headers["X-API-Key"] = self._api_key
40+
yield request
41+
1342

1443
class AgentControlClient:
1544
"""
@@ -45,6 +74,10 @@ def __init__(
4574
base_url: str | None = None,
4675
timeout: float = 30.0,
4776
api_key: str | None = None,
77+
runtime_auth_mode: RuntimeAuthMode | str | None = None,
78+
runtime_token_cache: RuntimeTokenCache | None = None,
79+
runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS),
80+
transport: httpx.AsyncBaseTransport | None = None,
4881
):
4982
"""
5083
Initialize the client.
@@ -55,13 +88,29 @@ def __init__(
5588
timeout: Request timeout in seconds
5689
api_key: API key for authentication. If not provided, will attempt
5790
to read from AGENT_CONTROL_API_KEY environment variable.
91+
runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto``
92+
attempts target-bound JWT exchange and falls back to normal
93+
request auth when the exchange endpoint is unavailable. ``jwt``
94+
requires a successful exchange. ``api_key`` and ``none`` keep
95+
evaluation requests on the normal request-auth path.
96+
runtime_token_cache: Optional cache shared across client instances.
97+
runtime_token_refresh_margin_seconds: Refresh cached runtime tokens
98+
before this many seconds of validity remain.
99+
transport: Optional httpx transport, primarily for tests.
58100
"""
59101
resolved_base_url = base_url or os.environ.get(
60102
self.BASE_URL_ENV_VAR, "http://localhost:8000"
61103
)
62104
self.base_url = resolved_base_url.rstrip("/")
63105
self.timeout = timeout
64106
self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
107+
configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR)
108+
self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode)
109+
if runtime_token_refresh_margin_seconds < 0:
110+
raise ValueError("runtime_token_refresh_margin_seconds must be >= 0.")
111+
self._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds
112+
self._runtime_token_cache = runtime_token_cache or RuntimeTokenCache()
113+
self._transport = transport
65114
self._client: httpx.AsyncClient | None = None
66115
self._server_version_warning_emitted = False
67116

@@ -70,15 +119,17 @@ def api_key(self) -> str | None:
70119
"""Get the configured API key (read-only)."""
71120
return self._api_key
72121

122+
@property
123+
def runtime_auth_mode(self) -> RuntimeAuthMode:
124+
"""Get the configured runtime auth mode (read-only)."""
125+
return self._runtime_auth_mode
126+
73127
def _get_headers(self) -> dict[str, str]:
74-
"""Build request headers including authentication."""
75-
headers: dict[str, str] = {
128+
"""Build base SDK metadata headers."""
129+
return {
76130
"X-Agent-Control-SDK": "python",
77131
"X-Agent-Control-SDK-Version": sdk_version,
78132
}
79-
if self._api_key:
80-
headers["X-API-Key"] = self._api_key
81-
return headers
82133

83134
async def _check_server_version(self, response: httpx.Response) -> None:
84135
"""Warn once when the server major version differs from the SDK major."""
@@ -108,6 +159,8 @@ async def __aenter__(self) -> "AgentControlClient":
108159
base_url=self.base_url,
109160
timeout=self.timeout,
110161
headers=self._get_headers(),
162+
auth=_AgentControlAuth(self._api_key),
163+
transport=self._transport,
111164
event_hooks={"response": [self._check_server_version]},
112165
)
113166
return self
@@ -137,6 +190,7 @@ async def health_check(self) -> dict[str, str]:
137190
response = await self._client.get("/health")
138191
response.raise_for_status()
139192
from typing import cast
193+
140194
return cast(dict[str, str], response.json())
141195

142196
@property
@@ -145,3 +199,130 @@ def http_client(self) -> httpx.AsyncClient:
145199
if self._client is None:
146200
raise RuntimeError("Client not initialized. Use 'async with' context manager.")
147201
return self._client
202+
203+
async def post_runtime_evaluation(
204+
self,
205+
*,
206+
json: dict[str, Any],
207+
headers: dict[str, str] | None = None,
208+
target_type: str | None = None,
209+
target_id: str | None = None,
210+
) -> httpx.Response:
211+
"""POST an evaluation request with runtime auth when configured."""
212+
runtime_authorization = await self._runtime_authorization(
213+
target_type=target_type,
214+
target_id=target_id,
215+
)
216+
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
217+
response = await self.http_client.post(
218+
"/api/v1/evaluation",
219+
json=json,
220+
headers=request_headers,
221+
)
222+
223+
if response.status_code == 401 and runtime_authorization is not None:
224+
await response.aread()
225+
runtime_authorization = await self._runtime_authorization(
226+
target_type=target_type,
227+
target_id=target_id,
228+
force_refresh=True,
229+
)
230+
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
231+
response = await self.http_client.post(
232+
"/api/v1/evaluation",
233+
json=json,
234+
headers=request_headers,
235+
)
236+
237+
return response
238+
239+
def _merge_runtime_headers(
240+
self,
241+
headers: dict[str, str] | None,
242+
runtime_authorization: str | None,
243+
) -> dict[str, str] | None:
244+
"""Merge caller headers with an optional Bearer token."""
245+
if headers is None and runtime_authorization is None:
246+
return None
247+
248+
merged = dict(headers or {})
249+
if runtime_authorization is not None:
250+
merged["Authorization"] = runtime_authorization
251+
return merged
252+
253+
async def _runtime_authorization(
254+
self,
255+
*,
256+
target_type: str | None,
257+
target_id: str | None,
258+
force_refresh: bool = False,
259+
) -> str | None:
260+
"""Return an Authorization header value for runtime evaluation."""
261+
if self._runtime_auth_mode in {"none", "api_key"}:
262+
return None
263+
264+
if target_type is None or target_id is None:
265+
if self._runtime_auth_mode == "jwt":
266+
raise RuntimeError(
267+
"runtime_auth_mode='jwt' requires target_type and target_id "
268+
"for evaluation requests."
269+
)
270+
return None
271+
272+
if self._runtime_auth_mode == "auto" and self._runtime_token_cache.is_jwt_unavailable(
273+
self.base_url, target_type, target_id
274+
):
275+
return None
276+
277+
if not force_refresh:
278+
cached = self._runtime_token_cache.get(
279+
self.base_url,
280+
target_type,
281+
target_id,
282+
refresh_margin_seconds=self._runtime_token_refresh_margin_seconds,
283+
)
284+
if cached is not None:
285+
return f"Bearer {cached.token}"
286+
287+
token = await self._exchange_runtime_token(
288+
target_type=target_type,
289+
target_id=target_id,
290+
)
291+
if token is None:
292+
return None
293+
return f"Bearer {token}"
294+
295+
async def _exchange_runtime_token(
296+
self,
297+
*,
298+
target_type: str,
299+
target_id: str,
300+
) -> str | None:
301+
"""Exchange the configured credential for a target-bound runtime token."""
302+
response = await self.http_client.post(
303+
"/api/v1/auth/runtime-token-exchange",
304+
json={"target_type": target_type, "target_id": target_id},
305+
)
306+
307+
if (
308+
self._runtime_auth_mode == "auto"
309+
and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES
310+
):
311+
self._runtime_token_cache.mark_jwt_unavailable(
312+
server_url=self.base_url,
313+
target_type=target_type,
314+
target_id=target_id,
315+
globally=response.status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES,
316+
)
317+
return None
318+
319+
response.raise_for_status()
320+
payload = response.json()
321+
if not isinstance(payload, dict):
322+
raise RuntimeError("Runtime token exchange response was not an object.")
323+
token = parse_runtime_token_exchange_response(
324+
cast(dict[str, object], payload),
325+
server_url=self.base_url,
326+
)
327+
self._runtime_token_cache.set(token)
328+
return token.token

0 commit comments

Comments
 (0)