Skip to content

Commit e470cac

Browse files
fix(sdk): propagate custom API key header
1 parent f68bf36 commit e470cac

7 files changed

Lines changed: 87 additions & 6 deletions

File tree

sdks/python/src/agent_control/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class _RefreshContext:
160160
agent_name: str
161161
server_url: str
162162
api_key: str | None
163+
api_key_header: str | None
163164
target_type: str | None
164165
target_id: str | None
165166

@@ -221,6 +222,7 @@ def _snapshot_refresh_context() -> _RefreshContext:
221222
agent = state.current_agent
222223
server_url = state.server_url
223224
api_key = state.api_key
225+
api_key_header = state.api_key_header
224226
target_type = state.target_type
225227
target_id = state.target_id
226228

@@ -234,6 +236,7 @@ def _snapshot_refresh_context() -> _RefreshContext:
234236
agent_name=agent.agent_name,
235237
server_url=server_url,
236238
api_key=api_key,
239+
api_key_header=api_key_header,
237240
target_type=target_type,
238241
target_id=target_id,
239242
)
@@ -244,6 +247,7 @@ async def _fetch_controls_for_context_async(context: _RefreshContext) -> list[di
244247
async with AgentControlClient(
245248
base_url=context.server_url,
246249
api_key=context.api_key,
250+
api_key_header=context.api_key_header,
247251
) as client:
248252
response = await agents.list_agent_controls(
249253
client,
@@ -430,6 +434,7 @@ def init(
430434
agent_version: str | None = None,
431435
server_url: str | None = None,
432436
api_key: str | None = None,
437+
api_key_header: str | None = None,
433438
controls_file: str | None = None,
434439
steps: list[StepSchemaDict] | None = None,
435440
conflict_mode: Literal["strict", "overwrite"] = "overwrite",
@@ -468,6 +473,8 @@ def init(
468473
server_url: Optional server URL (defaults to AGENT_CONTROL_URL env var
469474
or http://localhost:8000)
470475
api_key: Optional API key for authentication (defaults to AGENT_CONTROL_API_KEY env var)
476+
api_key_header: Optional HTTP header name for API key authentication
477+
(defaults to AGENT_CONTROL_API_KEY_HEADER env var or X-API-Key)
471478
controls_file: Optional explicit path to controls.yaml (auto-discovered if not provided)
472479
steps: Optional list of step schemas for registration:
473480
[{"type": "tool", "name": "search", "input_schema": {...}, "output_schema": {...}}]
@@ -562,6 +569,7 @@ async def handle(message: str):
562569
state.current_agent = next_agent
563570
state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000'
564571
state.api_key = api_key
572+
state.api_key_header = api_key_header
565573
state.runtime_token_cache.clear()
566574
state.target_type = target_type
567575
state.target_id = target_id
@@ -600,6 +608,7 @@ async def register() -> list[dict[str, Any]] | None:
600608
async with AgentControlClient(
601609
base_url=state.server_url,
602610
api_key=state.api_key,
611+
api_key_header=state.api_key_header,
603612
) as client:
604613
# Check server health first
605614
try:
@@ -686,6 +695,7 @@ def run_in_thread() -> None:
686695
batcher = init_observability(
687696
server_url=state.server_url,
688697
api_key=state.api_key,
698+
api_key_header=state.api_key_header,
689699
enabled=observability_enabled,
690700
sink_name=observability_sink_name,
691701
sink_config=observability_sink_config,
@@ -717,6 +727,7 @@ def _reset_state() -> None:
717727
state.server_controls = None
718728
state.server_url = None
719729
state.api_key = None
730+
state.api_key_header = None
720731
state.runtime_token_cache.clear()
721732
state.target_type = None
722733
state.target_id = None

sdks/python/src/agent_control/_state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self) -> None:
2626
self.server_controls: list[dict[str, Any]] | None = None
2727
self.server_url: str | None = None
2828
self.api_key: str | None = None
29+
self.api_key_header: str | None = None
2930
self.runtime_token_cache = RuntimeTokenCache()
3031
# Optional target context fixed at init() time; both fields are set
3132
# together or both remain None.

sdks/python/src/agent_control/evaluation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ async def evaluate_controls(
557557
async with AgentControlClient(
558558
base_url=state.server_url,
559559
api_key=state.api_key,
560+
api_key_header=state.api_key_header,
560561
runtime_token_cache=state.runtime_token_cache,
561562
) as client:
562563
return await check_evaluation_with_local(

sdks/python/src/agent_control/integrations/google_adk/plugin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,11 @@ async def _sync_steps_async(self, steps: list[StepSchemaDict]) -> None:
853853
"with the same agent_name as AgentControlPlugin."
854854
)
855855

856-
async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client:
856+
async with AgentControlClient(
857+
base_url=state.server_url,
858+
api_key=state.api_key,
859+
api_key_header=state.api_key_header,
860+
) as client:
857861
response = await agents.get_agent(client, self.agent_name)
858862
existing = GetAgentResponse.model_validate(response)
859863
existing_keys = {(step.type, step.name) for step in existing.steps}

sdks/python/tests/test_evaluation.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from uuid import UUID
55

66
import pytest
7-
from pydantic import ValidationError
8-
97
from agent_control import evaluation
108
from agent_control.evaluation import EvaluationResult
9+
from pydantic import ValidationError
1110

1211

1312
@pytest.mark.asyncio
@@ -126,6 +125,27 @@ async def test_evaluate_controls_with_context(monkeypatch):
126125
assert mock_check.call_args is not None
127126

128127

128+
@pytest.mark.asyncio
129+
async def test_evaluate_controls_uses_session_api_key_header(monkeypatch):
130+
"""evaluate_controls should pass init's API-key header into the client."""
131+
mock_result = EvaluationResult(is_safe=True, confidence=1.0)
132+
mock_check = AsyncMock(return_value=mock_result)
133+
monkeypatch.setattr(evaluation, "check_evaluation_with_local", mock_check)
134+
135+
with patch("agent_control.state.server_url", "http://localhost:8000"), patch(
136+
"agent_control.state.api_key", "test-key"
137+
), patch("agent_control.state.api_key_header", "Galileo-API-Key"):
138+
await evaluation.evaluate_controls(
139+
step_name="chat",
140+
input="hello",
141+
stage="pre",
142+
agent_name="test-bot",
143+
)
144+
145+
client = mock_check.call_args.kwargs["client"]
146+
assert client.api_key_header == "Galileo-API-Key"
147+
148+
129149
@pytest.mark.asyncio
130150
async def test_check_evaluation_forwards_target_context():
131151
"""When target_type and target_id are supplied, they are forwarded to the server."""

sdks/python/tests/test_init_step_merge.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
from collections.abc import Generator
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Any
88
from unittest.mock import ANY, AsyncMock, patch
99
from uuid import uuid4
1010

@@ -219,6 +219,47 @@ def test_init_omits_merge_events_from_public_signature() -> None:
219219
assert "merge_events" not in signature.parameters
220220

221221

222+
def test_init_passes_api_key_header_to_client_and_state() -> None:
223+
register_agent_mock = AsyncMock(return_value={"created": True, "controls": []})
224+
health_check_mock = AsyncMock(return_value={"status": "healthy"})
225+
client_init_kwargs: list[dict[str, Any]] = []
226+
original_init = agent_control.AgentControlClient.__init__
227+
228+
def recording_init(
229+
self: agent_control.AgentControlClient,
230+
*args: Any,
231+
**kwargs: Any,
232+
) -> None:
233+
client_init_kwargs.append(dict(kwargs))
234+
original_init(self, *args, **kwargs)
235+
236+
with patch.object(
237+
agent_control.AgentControlClient,
238+
"__init__",
239+
new=recording_init,
240+
), patch(
241+
"agent_control.__init__.AgentControlClient.health_check",
242+
new=health_check_mock,
243+
), patch(
244+
"agent_control.__init__.agents.register_agent",
245+
new=register_agent_mock,
246+
), patch.object(
247+
agent_control,
248+
"init_observability",
249+
return_value=None,
250+
) as observability_mock:
251+
agent_control.init(
252+
agent_name=f"agent-{uuid4().hex[:12]}",
253+
api_key="test-key",
254+
api_key_header="Galileo-API-Key",
255+
policy_refresh_interval_seconds=0,
256+
)
257+
258+
assert agent_control.state.api_key_header == "Galileo-API-Key"
259+
assert client_init_kwargs[0]["api_key_header"] == "Galileo-API-Key"
260+
assert observability_mock.call_args.kwargs["api_key_header"] == "Galileo-API-Key"
261+
262+
222263
@pytest.mark.asyncio
223264
async def test_refresh_controls_calls_agent_controls_endpoint() -> None:
224265
# Given: an initialized SDK agent session with network-facing calls mocked.
@@ -238,6 +279,7 @@ async def test_refresh_controls_calls_agent_controls_endpoint() -> None:
238279
):
239280
agent_control.init(
240281
agent_name=f"agent-{uuid4().hex[:12]}",
282+
api_key_header="Galileo-API-Key",
241283
policy_refresh_interval_seconds=0,
242284
)
243285

@@ -255,4 +297,5 @@ async def test_refresh_controls_calls_agent_controls_endpoint() -> None:
255297
target_type=None,
256298
target_id=None,
257299
)
300+
assert agent_control.state.api_key_header == "Galileo-API-Key"
258301
assert register_agent_mock.await_count == 0

sdks/python/tests/test_shutdown.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
from typing import Any
1010
from unittest.mock import AsyncMock, MagicMock, patch
1111

12-
import pytest
13-
1412
import agent_control
1513
import agent_control.observability as obs_mod
14+
import pytest
1615
from agent_control._state import state
1716
from agent_control.observability import EventBatcher
1817

@@ -64,6 +63,7 @@ def test_shutdown_resets_state(self):
6463
state.server_controls = [{"name": "test"}]
6564
state.server_url = "http://localhost:8000"
6665
state.api_key = "key"
66+
state.api_key_header = "X-Custom-API-Key"
6767

6868
agent_control.shutdown()
6969

@@ -72,6 +72,7 @@ def test_shutdown_resets_state(self):
7272
assert state.server_controls is None
7373
assert state.server_url is None
7474
assert state.api_key is None
75+
assert state.api_key_header is None
7576

7677
def test_shutdown_idempotent(self):
7778
agent_control.shutdown()

0 commit comments

Comments
 (0)