22
33import logging
44import os
5+ from collections .abc import Generator
56from types import TracebackType
7+ from typing import Any , cast
68
79import httpx
810
911from . import __version__ as sdk_version
12+ from .runtime_auth import (
13+ RuntimeAuthMode ,
14+ RuntimeTokenCache ,
15+ normalize_runtime_auth_mode ,
16+ parse_runtime_token_exchange_response ,
17+ )
1018
1119_logger = logging .getLogger (__name__ )
1220
# Environment variable read by AgentControlClient.__init__ when no
# runtime_auth_mode argument is supplied.
_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
# Default for the runtime_token_refresh_margin_seconds constructor argument.
_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
# Token-exchange response statuses that, in "auto" mode, make the exchange
# return no token (falling back to normal request auth) instead of raising.
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
# Fallback statuses for which the cache is told JWT is unavailable "globally".
# NOTE(review): currently identical to the auto-fallback set, so every fallback
# is recorded as global -- confirm this is intended.
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
25+
26+
class _AgentControlAuth(httpx.Auth):
    """httpx auth hook that adds the configured API key as ``X-API-Key``.

    The header is only added when an API key is configured and the outgoing
    request carries neither an ``Authorization`` nor an ``X-API-Key`` header.
    """

    def __init__(self, api_key: str | None) -> None:
        """Remember the API key (may be ``None`` or empty, meaning no auth)."""
        self._api_key = api_key

    def auth_flow(
        self,
        request: httpx.Request,
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Yield the request once, adding ``X-API-Key`` first when appropriate."""
        key = self._api_key
        outgoing = request.headers
        # Explicit per-request credentials always win over the default key.
        if key and "Authorization" not in outgoing and "X-API-Key" not in outgoing:
            outgoing["X-API-Key"] = key
        yield request
41+
1342
1443class AgentControlClient :
1544 """
@@ -45,6 +74,10 @@ def __init__(
4574 base_url : str | None = None ,
4675 timeout : float = 30.0 ,
4776 api_key : str | None = None ,
77+ runtime_auth_mode : RuntimeAuthMode | str | None = None ,
78+ runtime_token_cache : RuntimeTokenCache | None = None ,
79+ runtime_token_refresh_margin_seconds : int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS ),
80+ transport : httpx .AsyncBaseTransport | None = None ,
4881 ):
4982 """
5083 Initialize the client.
@@ -55,13 +88,29 @@ def __init__(
5588 timeout: Request timeout in seconds
5689 api_key: API key for authentication. If not provided, will attempt
5790 to read from AGENT_CONTROL_API_KEY environment variable.
91+ runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto``
92+ attempts target-bound JWT exchange and falls back to normal
93+ request auth when the exchange endpoint is unavailable. ``jwt``
94+ requires a successful exchange. ``api_key`` and ``none`` keep
95+ evaluation requests on the normal request-auth path.
96+ runtime_token_cache: Optional cache shared across client instances.
97+ runtime_token_refresh_margin_seconds: Refresh cached runtime tokens
98+ before this many seconds of validity remain.
99+ transport: Optional httpx transport, primarily for tests.
58100 """
59101 resolved_base_url = base_url or os .environ .get (
60102 self .BASE_URL_ENV_VAR , "http://localhost:8000"
61103 )
62104 self .base_url = resolved_base_url .rstrip ("/" )
63105 self .timeout = timeout
64106 self ._api_key = api_key or os .environ .get (self .API_KEY_ENV_VAR )
107+ configured_runtime_mode = runtime_auth_mode or os .environ .get (_RUNTIME_AUTH_MODE_ENV_VAR )
108+ self ._runtime_auth_mode = normalize_runtime_auth_mode (configured_runtime_mode )
109+ if runtime_token_refresh_margin_seconds < 0 :
110+ raise ValueError ("runtime_token_refresh_margin_seconds must be >= 0." )
111+ self ._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds
112+ self ._runtime_token_cache = runtime_token_cache or RuntimeTokenCache ()
113+ self ._transport = transport
65114 self ._client : httpx .AsyncClient | None = None
66115 self ._server_version_warning_emitted = False
67116
@@ -70,15 +119,17 @@ def api_key(self) -> str | None:
70119 """Get the configured API key (read-only)."""
71120 return self ._api_key
72121
@property
def runtime_auth_mode(self) -> RuntimeAuthMode:
    """Return the runtime auth mode resolved in ``__init__`` (read-only)."""
    mode: RuntimeAuthMode = self._runtime_auth_mode
    return mode
126+
def _get_headers(self) -> dict[str, str]:
    """Return the SDK identification headers.

    Only SDK metadata is included; no credential header is produced here.
    """
    sdk_headers: dict[str, str] = {}
    sdk_headers["X-Agent-Control-SDK"] = "python"
    sdk_headers["X-Agent-Control-SDK-Version"] = sdk_version
    return sdk_headers
82133
83134 async def _check_server_version (self , response : httpx .Response ) -> None :
84135 """Warn once when the server major version differs from the SDK major."""
@@ -108,6 +159,8 @@ async def __aenter__(self) -> "AgentControlClient":
108159 base_url = self .base_url ,
109160 timeout = self .timeout ,
110161 headers = self ._get_headers (),
162+ auth = _AgentControlAuth (self ._api_key ),
163+ transport = self ._transport ,
111164 event_hooks = {"response" : [self ._check_server_version ]},
112165 )
113166 return self
@@ -137,6 +190,7 @@ async def health_check(self) -> dict[str, str]:
137190 response = await self ._client .get ("/health" )
138191 response .raise_for_status ()
139192 from typing import cast
193+
140194 return cast (dict [str , str ], response .json ())
141195
142196 @property
@@ -145,3 +199,151 @@ def http_client(self) -> httpx.AsyncClient:
145199 if self ._client is None :
146200 raise RuntimeError ("Client not initialized. Use 'async with' context manager." )
147201 return self ._client
202+
async def post_runtime_evaluation(
    self,
    *,
    json: dict[str, Any],
    headers: dict[str, str] | None = None,
    target_type: str | None = None,
    target_id: str | None = None,
) -> httpx.Response:
    """POST ``json`` to the evaluation endpoint, using runtime auth if available.

    Args:
        json: Request body for the evaluation endpoint.
        headers: Optional extra headers supplied by the caller.
        target_type: Target type passed to the runtime-authorization lookup.
        target_id: Target id passed to the runtime-authorization lookup.

    Returns:
        The response of the first attempt, or -- when that attempt returned
        401 while a runtime Authorization value was in use -- the response of
        a single retry made with a force-refreshed token.
    """
    endpoint = "/api/v1/evaluation"

    bearer = await self._runtime_authorization(
        target_type=target_type,
        target_id=target_id,
    )
    attempt_headers = self._merge_runtime_headers(headers, bearer)
    first_response = await self.http_client.post(
        endpoint,
        json=json,
        headers=attempt_headers,
    )

    # Nothing to retry unless a runtime token was sent and it was rejected.
    if first_response.status_code != 401 or bearer is None:
        return first_response

    # Read the rejected response body before issuing the retry.
    await first_response.aread()
    bearer = await self._runtime_authorization(
        target_type=target_type,
        target_id=target_id,
        force_refresh=True,
        allow_auto_fallback=False,
    )
    attempt_headers = self._merge_runtime_headers(headers, bearer)
    return await self.http_client.post(
        endpoint,
        json=json,
        headers=attempt_headers,
    )
239+
240+ def _merge_runtime_headers (
241+ self ,
242+ headers : dict [str , str ] | None ,
243+ runtime_authorization : str | None ,
244+ ) -> dict [str , str ] | None :
245+ """Merge caller headers with an optional Bearer token."""
246+ if headers is None and runtime_authorization is None :
247+ return None
248+
249+ merged = dict (headers or {})
250+ if runtime_authorization is not None :
251+ merged ["Authorization" ] = runtime_authorization
252+ return merged
253+
254+ async def _runtime_authorization (
255+ self ,
256+ * ,
257+ target_type : str | None ,
258+ target_id : str | None ,
259+ force_refresh : bool = False ,
260+ allow_auto_fallback : bool = True ,
261+ ) -> str | None :
262+ """Return an Authorization header value for runtime evaluation."""
263+ if self ._runtime_auth_mode in {"none" , "api_key" }:
264+ return None
265+
266+ if target_type is None or target_id is None :
267+ if self ._runtime_auth_mode == "jwt" :
268+ raise RuntimeError (
269+ "runtime_auth_mode='jwt' requires target_type and target_id "
270+ "for evaluation requests."
271+ )
272+ return None
273+
274+ if self ._runtime_auth_mode == "auto" and self ._runtime_token_cache .is_jwt_unavailable (
275+ self .base_url , target_type , target_id
276+ ):
277+ return None
278+
279+ if not force_refresh :
280+ cached = self ._runtime_token_cache .get (
281+ self .base_url ,
282+ target_type ,
283+ target_id ,
284+ refresh_margin_seconds = self ._runtime_token_refresh_margin_seconds ,
285+ )
286+ if cached is not None :
287+ return f"Bearer { cached .token } "
288+
289+ exchange_lock = self ._runtime_token_cache .exchange_lock (
290+ self .base_url ,
291+ target_type ,
292+ target_id ,
293+ )
294+ async with exchange_lock :
295+ if not force_refresh :
296+ cached = self ._runtime_token_cache .get (
297+ self .base_url ,
298+ target_type ,
299+ target_id ,
300+ refresh_margin_seconds = self ._runtime_token_refresh_margin_seconds ,
301+ )
302+ if cached is not None :
303+ return f"Bearer { cached .token } "
304+
305+ token = await self ._exchange_runtime_token (
306+ target_type = target_type ,
307+ target_id = target_id ,
308+ allow_auto_fallback = allow_auto_fallback ,
309+ )
310+ if token is None :
311+ return None
312+ return f"Bearer { token } "
313+
async def _exchange_runtime_token(
    self,
    *,
    target_type: str,
    target_id: str,
    allow_auto_fallback: bool = True,
) -> str | None:
    """Request a runtime token for the given target from the exchange endpoint.

    Args:
        target_type: Target type sent in the exchange request body.
        target_id: Target id sent in the exchange request body.
        allow_auto_fallback: When true and the mode is ``auto``, a response
            status listed in ``_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES`` is
            recorded in the cache as "JWT unavailable" and ``None`` is
            returned instead of raising.

    Returns:
        The token string, or ``None`` when the auto fallback was taken.

    Raises:
        httpx.HTTPStatusError: For error statuses not handled by the fallback.
        RuntimeError: If the response JSON is not an object.
    """
    exchange_response = await self.http_client.post(
        "/api/v1/auth/runtime-token-exchange",
        json={"target_type": target_type, "target_id": target_id},
    )
    status = exchange_response.status_code

    fallback_permitted = self._runtime_auth_mode == "auto" and allow_auto_fallback
    if fallback_permitted and status in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES:
        self._runtime_token_cache.mark_jwt_unavailable(
            server_url=self.base_url,
            target_type=target_type,
            target_id=target_id,
            globally=status in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES,
        )
        return None

    exchange_response.raise_for_status()

    body = exchange_response.json()
    if not isinstance(body, dict):
        raise RuntimeError("Runtime token exchange response was not an object.")

    runtime_token = parse_runtime_token_exchange_response(
        cast(dict[str, object], body),
        server_url=self.base_url,
    )
    # Store the parsed token so later lookups can reuse it.
    self._runtime_token_cache.set(runtime_token)
    return runtime_token.token
0 commit comments