Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/src/agent_control_server/auth_framework/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
_UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN"
_UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER"
_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS"
_UPSTREAM_CA_FILE_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_CA_FILE"

# Runtime flow.
_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
Expand Down Expand Up @@ -216,6 +217,7 @@ def _build_default_provider() -> RequestAuthorizer:
extra_forward_headers = _parse_extra_forward_headers(
os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV)
)
ca_file = (os.environ.get(_UPSTREAM_CA_FILE_ENV) or "").strip() or None
_logger.info("Default auth provider: http_upstream url=%s", url)
return HttpUpstreamAuthProvider(
HttpUpstreamConfig(
Expand All @@ -224,6 +226,7 @@ def _build_default_provider() -> RequestAuthorizer:
service_token=token,
service_token_header=token_header,
extra_forward_headers=extra_forward_headers,
ca_file=ca_file,
)
)
raise RuntimeError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from __future__ import annotations

import ssl
from dataclasses import dataclass
from datetime import datetime
from typing import Any
Expand Down Expand Up @@ -150,6 +151,9 @@ class HttpUpstreamConfig:
dropped. Names duplicating the default set or each other (after
case-folding) are deduplicated."""

ca_file: str | None = None
"""Optional CA bundle path used only when verifying the auth upstream."""

def __post_init__(self) -> None:
if self.service_token is None:
return
Expand All @@ -174,7 +178,16 @@ def __init__(
) -> None:
self._config = config
self._owns_client = client is None
self._client = client or httpx.AsyncClient(timeout=config.timeout_seconds)
if client is not None:
self._client = client
elif config.ca_file is not None:
ssl_context = ssl.create_default_context(cafile=config.ca_file)
self._client = httpx.AsyncClient(
timeout=config.timeout_seconds,
verify=ssl_context,
)
else:
self._client = httpx.AsyncClient(timeout=config.timeout_seconds)

async def aclose(self) -> None:
"""Release the HTTP client if this provider created it."""
Expand Down
74 changes: 72 additions & 2 deletions server/tests/test_auth_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import httpx
import pytest

from agent_control_server.auth_framework.core import (
Operation,
Principal,
Expand Down Expand Up @@ -227,6 +226,33 @@ def _build_upstream(
return HttpUpstreamAuthProvider(config, client=client)


def _patch_owned_upstream_client(monkeypatch) -> dict[str, Any]:
captured: dict[str, Any] = {}
ssl_context = object()

class FakeAsyncClient:
def __init__(self, **kwargs: Any) -> None:
captured.update(kwargs)

async def aclose(self) -> None:
captured["closed"] = True

def fake_create_default_context(*, cafile: str | None = None) -> object:
captured["cafile"] = cafile
return ssl_context

monkeypatch.setattr(
"agent_control_server.auth_framework.providers.http_upstream.httpx.AsyncClient",
FakeAsyncClient,
)
monkeypatch.setattr(
"agent_control_server.auth_framework.providers.http_upstream.ssl.create_default_context",
fake_create_default_context,
)
captured["ssl_context"] = ssl_context
return captured


@pytest.mark.asyncio
async def test_http_upstream_returns_principal_on_200():
captured: dict[str, Any] = {}
Expand Down Expand Up @@ -291,6 +317,26 @@ def test_http_upstream_rejects_extra_forwarded_service_token_header_collision():
)


@pytest.mark.asyncio
async def test_http_upstream_uses_ca_file_for_owned_client(monkeypatch):
captured = _patch_owned_upstream_client(monkeypatch)

provider = HttpUpstreamAuthProvider(
HttpUpstreamConfig(
url="https://upstream.example/check",
timeout_seconds=2.5,
ca_file="/etc/agent-control/auth-upstream-ca/ca.crt",
)
)

await provider.aclose()

assert captured["timeout"] == 2.5
assert captured["cafile"] == "/etc/agent-control/auth-upstream-ca/ca.crt"
assert captured["verify"] is captured["ssl_context"]
assert captured["closed"] is True


@pytest.mark.asyncio
async def test_http_upstream_forwards_extra_headers():
# Given: a provider configured with an extra header in its forward list
Expand Down Expand Up @@ -772,7 +818,6 @@ def test_runtime_token_rejects_empty_required_claims(kwargs, message):
def test_runtime_token_rejects_management_token_passed_to_runtime_verify():
"""A token without ``domain=runtime`` must be rejected by runtime verify."""
import jwt

from agent_control_server.auth_framework.runtime_token import (
RuntimeTokenError,
verify_runtime_token,
Expand Down Expand Up @@ -1422,6 +1467,31 @@ async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch):
await auth_config.teardown_auth()


@pytest.mark.asyncio
async def test_configure_http_upstream_ca_file_env(monkeypatch):
from agent_control_server.auth_framework import config as auth_config

clear_authorizers()
captured = _patch_owned_upstream_client(monkeypatch)

monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream")
monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check")
monkeypatch.setenv(
"AGENT_CONTROL_AUTH_UPSTREAM_CA_FILE",
" /etc/agent-control/auth-upstream-ca/ca.crt ",
)

try:
auth_config.configure_auth_from_env()
provider = get_authorizer(Operation.CONTROLS_READ)
assert isinstance(provider, HttpUpstreamAuthProvider)
assert provider._config.ca_file == "/etc/agent-control/auth-upstream-ca/ca.crt"
assert captured["cafile"] == "/etc/agent-control/auth-upstream-ca/ca.crt"
assert captured["verify"] is captured["ssl_context"]
finally:
await auth_config.teardown_auth()


def test_configure_runtime_jwt_requires_secret(monkeypatch):
from agent_control_server.auth_framework import config as auth_config

Expand Down
Loading