diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 8d8fbcd9..cf5fb67f 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -48,6 +48,7 @@ _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" _UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" +_UPSTREAM_CA_FILE_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_CA_FILE" # Runtime flow. _RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" @@ -216,6 +217,7 @@ def _build_default_provider() -> RequestAuthorizer: extra_forward_headers = _parse_extra_forward_headers( os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) ) + ca_file = (os.environ.get(_UPSTREAM_CA_FILE_ENV) or "").strip() or None _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -224,6 +226,7 @@ def _build_default_provider() -> RequestAuthorizer: service_token=token, service_token_header=token_header, extra_forward_headers=extra_forward_headers, + ca_file=ca_file, ) ) raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index f6ce3b5f..cbd2cb18 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -41,6 +41,7 @@ from __future__ import annotations +import ssl from dataclasses import dataclass from datetime import datetime from typing import Any @@ -150,6 +151,9 @@ class HttpUpstreamConfig: dropped. Names duplicating the default set or each other (after case-folding) are deduplicated.""" + ca_file: str | None = None + """Optional CA bundle path used only when verifying the auth upstream.""" + def __post_init__(self) -> None: if self.service_token is None: return @@ -174,7 +178,16 @@ def __init__( ) -> None: self._config = config self._owns_client = client is None - self._client = client or httpx.AsyncClient(timeout=config.timeout_seconds) + if client is not None: + self._client = client + elif config.ca_file is not None: + ssl_context = ssl.create_default_context(cafile=config.ca_file) + self._client = httpx.AsyncClient( + timeout=config.timeout_seconds, + verify=ssl_context, + ) + else: + self._client = httpx.AsyncClient(timeout=config.timeout_seconds) async def aclose(self) -> None: """Release the HTTP client if this provider created it.""" diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 5f31c52f..874d0317 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -227,6 +226,33 @@ def _build_upstream( return HttpUpstreamAuthProvider(config, client=client) +def _patch_owned_upstream_client(monkeypatch) -> dict[str, Any]: + captured: dict[str, Any] = {} + ssl_context = object() + + class FakeAsyncClient: + def __init__(self, **kwargs: Any) -> None: + captured.update(kwargs) + + async def aclose(self) -> None: + captured["closed"] = True + + def fake_create_default_context(*, cafile: str | None = None) -> object: + captured["cafile"] = cafile + return ssl_context + + monkeypatch.setattr( + "agent_control_server.auth_framework.providers.http_upstream.httpx.AsyncClient", + FakeAsyncClient, + ) + monkeypatch.setattr( + "agent_control_server.auth_framework.providers.http_upstream.ssl.create_default_context", + fake_create_default_context, + ) + captured["ssl_context"] = ssl_context + return captured + + @pytest.mark.asyncio async def test_http_upstream_returns_principal_on_200(): captured: dict[str, Any] = {} @@ -291,6 +317,26 @@ def test_http_upstream_rejects_extra_forwarded_service_token_header_collision(): ) +@pytest.mark.asyncio +async def test_http_upstream_uses_ca_file_for_owned_client(monkeypatch): + captured = _patch_owned_upstream_client(monkeypatch) + + provider = HttpUpstreamAuthProvider( + HttpUpstreamConfig( + url="https://upstream.example/check", + timeout_seconds=2.5, + ca_file="/etc/agent-control/auth-upstream-ca/ca.crt", + ) + ) + + await provider.aclose() + + assert captured["timeout"] == 2.5 + assert captured["cafile"] == "/etc/agent-control/auth-upstream-ca/ca.crt" + assert captured["verify"] is captured["ssl_context"] + assert captured["closed"] is True + + @pytest.mark.asyncio async def test_http_upstream_forwards_extra_headers(): # Given: a provider configured with an extra header in its forward list @@ -772,7 +818,6 @@ def test_runtime_token_rejects_empty_required_claims(kwargs, message): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1422,6 +1467,31 @@ async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): await auth_config.teardown_auth() +@pytest.mark.asyncio +async def test_configure_http_upstream_ca_file_env(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + captured = _patch_owned_upstream_client(monkeypatch) + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_CA_FILE", + " /etc/agent-control/auth-upstream-ca/ca.crt ", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.ca_file == "/etc/agent-control/auth-upstream-ca/ca.crt" + assert captured["cafile"] == "/etc/agent-control/auth-upstream-ca/ca.crt" + assert captured["verify"] is captured["ssl_context"] + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config