22
33import logging
44import os
5+ from collections .abc import Generator
56from types import TracebackType
7+ from typing import Any , cast
68
79import httpx
810
911from . import __version__ as sdk_version
12+ from .runtime_auth import (
13+ RuntimeAuthMode ,
14+ RuntimeTokenCache ,
15+ normalize_runtime_auth_mode ,
16+ parse_runtime_token_exchange_response ,
17+ )
1018
1119_logger = logging .getLogger (__name__ )
1220
21+ _RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
22+ _DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
23+ _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {400 , 401 , 403 , 404 , 503 }
24+ _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404 , 503 }
25+
26+
27+ class _AgentControlAuth (httpx .Auth ):
28+ """Attach local API-key credentials unless a request already has Bearer auth."""
29+
30+ def __init__ (self , api_key : str | None ) -> None :
31+ self ._api_key = api_key
32+
33+ def auth_flow (
34+ self ,
35+ request : httpx .Request ,
36+ ) -> Generator [httpx .Request , httpx .Response , None ]:
37+ if self ._api_key and "Authorization" not in request .headers :
38+ if "X-API-Key" not in request .headers :
39+ request .headers ["X-API-Key" ] = self ._api_key
40+ yield request
41+
1342
1443class AgentControlClient :
1544 """
@@ -45,6 +74,10 @@ def __init__(
4574 base_url : str | None = None ,
4675 timeout : float = 30.0 ,
4776 api_key : str | None = None ,
77+ runtime_auth_mode : RuntimeAuthMode | str | None = None ,
78+ runtime_token_cache : RuntimeTokenCache | None = None ,
79+ runtime_token_refresh_margin_seconds : int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS ),
80+ transport : httpx .AsyncBaseTransport | None = None ,
4881 ):
4982 """
5083 Initialize the client.
@@ -55,13 +88,29 @@ def __init__(
5588 timeout: Request timeout in seconds
5689 api_key: API key for authentication. If not provided, will attempt
5790 to read from AGENT_CONTROL_API_KEY environment variable.
91+ runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto``
92+ attempts target-bound JWT exchange and falls back to normal
93+ request auth when the exchange endpoint is unavailable. ``jwt``
94+ requires a successful exchange. ``api_key`` and ``none`` keep
95+ evaluation requests on the normal request-auth path.
96+ runtime_token_cache: Optional cache shared across client instances.
97+ runtime_token_refresh_margin_seconds: Refresh cached runtime tokens
98+ before this many seconds of validity remain.
99+ transport: Optional httpx transport, primarily for tests.
58100 """
59101 resolved_base_url = base_url or os .environ .get (
60102 self .BASE_URL_ENV_VAR , "http://localhost:8000"
61103 )
62104 self .base_url = resolved_base_url .rstrip ("/" )
63105 self .timeout = timeout
64106 self ._api_key = api_key or os .environ .get (self .API_KEY_ENV_VAR )
107+ configured_runtime_mode = runtime_auth_mode or os .environ .get (_RUNTIME_AUTH_MODE_ENV_VAR )
108+ self ._runtime_auth_mode = normalize_runtime_auth_mode (configured_runtime_mode )
109+ if runtime_token_refresh_margin_seconds < 0 :
110+ raise ValueError ("runtime_token_refresh_margin_seconds must be >= 0." )
111+ self ._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds
112+ self ._runtime_token_cache = runtime_token_cache or RuntimeTokenCache ()
113+ self ._transport = transport
65114 self ._client : httpx .AsyncClient | None = None
66115 self ._server_version_warning_emitted = False
67116
@@ -70,15 +119,17 @@ def api_key(self) -> str | None:
70119 """Get the configured API key (read-only)."""
71120 return self ._api_key
72121
122+ @property
123+ def runtime_auth_mode (self ) -> RuntimeAuthMode :
124+ """Get the configured runtime auth mode (read-only)."""
125+ return self ._runtime_auth_mode
126+
73127 def _get_headers (self ) -> dict [str , str ]:
74- """Build request headers including authentication ."""
75- headers : dict [ str , str ] = {
128+ """Build base SDK metadata headers ."""
129+ return {
76130 "X-Agent-Control-SDK" : "python" ,
77131 "X-Agent-Control-SDK-Version" : sdk_version ,
78132 }
79- if self ._api_key :
80- headers ["X-API-Key" ] = self ._api_key
81- return headers
82133
83134 async def _check_server_version (self , response : httpx .Response ) -> None :
84135 """Warn once when the server major version differs from the SDK major."""
@@ -108,6 +159,8 @@ async def __aenter__(self) -> "AgentControlClient":
108159 base_url = self .base_url ,
109160 timeout = self .timeout ,
110161 headers = self ._get_headers (),
162+ auth = _AgentControlAuth (self ._api_key ),
163+ transport = self ._transport ,
111164 event_hooks = {"response" : [self ._check_server_version ]},
112165 )
113166 return self
@@ -137,6 +190,7 @@ async def health_check(self) -> dict[str, str]:
137190 response = await self ._client .get ("/health" )
138191 response .raise_for_status ()
139192 from typing import cast
193+
140194 return cast (dict [str , str ], response .json ())
141195
142196 @property
@@ -145,3 +199,130 @@ def http_client(self) -> httpx.AsyncClient:
145199 if self ._client is None :
146200 raise RuntimeError ("Client not initialized. Use 'async with' context manager." )
147201 return self ._client
202+
203+ async def post_runtime_evaluation (
204+ self ,
205+ * ,
206+ json : dict [str , Any ],
207+ headers : dict [str , str ] | None = None ,
208+ target_type : str | None = None ,
209+ target_id : str | None = None ,
210+ ) -> httpx .Response :
211+ """POST an evaluation request with runtime auth when configured."""
212+ runtime_authorization = await self ._runtime_authorization (
213+ target_type = target_type ,
214+ target_id = target_id ,
215+ )
216+ request_headers = self ._merge_runtime_headers (headers , runtime_authorization )
217+ response = await self .http_client .post (
218+ "/api/v1/evaluation" ,
219+ json = json ,
220+ headers = request_headers ,
221+ )
222+
223+ if response .status_code == 401 and runtime_authorization is not None :
224+ await response .aread ()
225+ runtime_authorization = await self ._runtime_authorization (
226+ target_type = target_type ,
227+ target_id = target_id ,
228+ force_refresh = True ,
229+ )
230+ request_headers = self ._merge_runtime_headers (headers , runtime_authorization )
231+ response = await self .http_client .post (
232+ "/api/v1/evaluation" ,
233+ json = json ,
234+ headers = request_headers ,
235+ )
236+
237+ return response
238+
239+ def _merge_runtime_headers (
240+ self ,
241+ headers : dict [str , str ] | None ,
242+ runtime_authorization : str | None ,
243+ ) -> dict [str , str ] | None :
244+ """Merge caller headers with an optional Bearer token."""
245+ if headers is None and runtime_authorization is None :
246+ return None
247+
248+ merged = dict (headers or {})
249+ if runtime_authorization is not None :
250+ merged ["Authorization" ] = runtime_authorization
251+ return merged
252+
253+ async def _runtime_authorization (
254+ self ,
255+ * ,
256+ target_type : str | None ,
257+ target_id : str | None ,
258+ force_refresh : bool = False ,
259+ ) -> str | None :
260+ """Return an Authorization header value for runtime evaluation."""
261+ if self ._runtime_auth_mode in {"none" , "api_key" }:
262+ return None
263+
264+ if target_type is None or target_id is None :
265+ if self ._runtime_auth_mode == "jwt" :
266+ raise RuntimeError (
267+ "runtime_auth_mode='jwt' requires target_type and target_id "
268+ "for evaluation requests."
269+ )
270+ return None
271+
272+ if self ._runtime_auth_mode == "auto" and self ._runtime_token_cache .is_jwt_unavailable (
273+ self .base_url , target_type , target_id
274+ ):
275+ return None
276+
277+ if not force_refresh :
278+ cached = self ._runtime_token_cache .get (
279+ self .base_url ,
280+ target_type ,
281+ target_id ,
282+ refresh_margin_seconds = self ._runtime_token_refresh_margin_seconds ,
283+ )
284+ if cached is not None :
285+ return f"Bearer { cached .token } "
286+
287+ token = await self ._exchange_runtime_token (
288+ target_type = target_type ,
289+ target_id = target_id ,
290+ )
291+ if token is None :
292+ return None
293+ return f"Bearer { token } "
294+
295+ async def _exchange_runtime_token (
296+ self ,
297+ * ,
298+ target_type : str ,
299+ target_id : str ,
300+ ) -> str | None :
301+ """Exchange the configured credential for a target-bound runtime token."""
302+ response = await self .http_client .post (
303+ "/api/v1/auth/runtime-token-exchange" ,
304+ json = {"target_type" : target_type , "target_id" : target_id },
305+ )
306+
307+ if (
308+ self ._runtime_auth_mode == "auto"
309+ and response .status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES
310+ ):
311+ self ._runtime_token_cache .mark_jwt_unavailable (
312+ server_url = self .base_url ,
313+ target_type = target_type ,
314+ target_id = target_id ,
315+ globally = response .status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES ,
316+ )
317+ return None
318+
319+ response .raise_for_status ()
320+ payload = response .json ()
321+ if not isinstance (payload , dict ):
322+ raise RuntimeError ("Runtime token exchange response was not an object." )
323+ token = parse_runtime_token_exchange_response (
324+ cast (dict [str , object ], payload ),
325+ server_url = self .base_url ,
326+ )
327+ self ._runtime_token_cache .set (token )
328+ return token .token
0 commit comments