From fbfcf4e37b5466b47c5161ba51fbc2fd10d3c9f1 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:14:39 +0530 Subject: [PATCH 01/20] feat(server): add runtime auth namespace cutover Add explicit none, api_key, and jwt runtime auth modes, including a generic no-auth provider. Move controls, bindings, policies, agents, and evaluation storage lookups onto principal namespace scoping. Cover auth mode selection and principal namespace isolation with server tests. --- .../auth_framework/__init__.py | 7 +- .../auth_framework/config.py | 120 +++++++++++---- .../auth_framework/core.py | 16 +- .../auth_framework/providers/__init__.py | 2 + .../auth_framework/providers/header.py | 39 +++-- .../auth_framework/providers/http_upstream.py | 6 +- .../auth_framework/providers/local_jwt.py | 2 +- .../auth_framework/providers/no_auth.py | 29 ++++ .../agent_control_server/endpoints/agents.py | 92 +++++++----- .../agent_control_server/endpoints/auth.py | 11 +- .../endpoints/control_bindings.py | 47 +++--- .../endpoints/controls.py | 87 +++++++---- .../endpoints/evaluation.py | 26 +++- .../endpoints/policies.py | 77 +++++++--- server/src/agent_control_server/main.py | 19 ++- .../agent_control_server/services/controls.py | 140 +++++++++++++---- server/tests/test_auth_framework.py | 96 +++++++++++- server/tests/test_controls_additional.py | 15 +- server/tests/test_controls_auth.py | 27 ++-- server/tests/test_principal_namespace_flow.py | 141 ++++++++++++++++++ server/tests/test_target_merged_contract.py | 6 +- 21 files changed, 753 insertions(+), 252 deletions(-) create mode 100644 server/src/agent_control_server/auth_framework/providers/no_auth.py create mode 100644 server/tests/test_principal_namespace_flow.py diff --git a/server/src/agent_control_server/auth_framework/__init__.py b/server/src/agent_control_server/auth_framework/__init__.py index 57368d57..0333f2cc 100644 --- a/server/src/agent_control_server/auth_framework/__init__.py +++ 
b/server/src/agent_control_server/auth_framework/__init__.py @@ -2,10 +2,9 @@ Endpoints declare an :class:`Operation` they need; an installed :class:`RequestAuthorizer` decides whether the request is allowed and -returns the resulting :class:`Principal`. Two providers ship in-tree: -:class:`HeaderAuthProvider` (uses local credential checks) and -:class:`HttpUpstreamAuthProvider` (delegates to a configurable -upstream HTTP service). +returns the resulting :class:`Principal`. Providers ship in-tree for +disabled auth, local credential checks, upstream HTTP authorization, +and local runtime-JWT verification. """ from .core import ( diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 92107b0e..c8f428dc 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -8,15 +8,19 @@ - **Default flow** (everything except runtime). One authorizer handles every operation that does not have a specific override: - :class:`HeaderAuthProvider` (local credentials) or + :class:`NoAuthProvider` (no credentials), + :class:`HeaderAuthProvider` (local API keys), or :class:`HttpUpstreamAuthProvider` (forwards to a configurable URL). -- **Runtime flow.** When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is - configured, :class:`LocalJwtVerifyProvider` is registered as the - override for :data:`Operation.RUNTIME_USE`; the - ``runtime.token_exchange`` operation continues to flow through the - default authorizer because the exchange itself is shaped like a - management call (forward credential, get grant). Without the secret, - no runtime override is installed. +- **Runtime flow.** ``AGENT_CONTROL_RUNTIME_AUTH_MODE`` selects the + override for :data:`Operation.RUNTIME_USE`: ``none`` uses + :class:`NoAuthProvider`, ``api_key`` uses + :class:`HeaderAuthProvider`, and ``jwt`` uses + :class:`LocalJwtVerifyProvider`. 
When the mode is unset, startup + preserves historical behavior by selecting ``jwt`` if + ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + The ``runtime.token_exchange`` operation continues to flow through + the default authorizer because the exchange itself is shaped like a + management call (forward credential, get grant). """ from __future__ import annotations @@ -30,6 +34,7 @@ HeaderAuthProvider, HttpUpstreamAuthProvider, LocalJwtVerifyProvider, + NoAuthProvider, ) from .providers.http_upstream import HttpUpstreamConfig @@ -43,6 +48,7 @@ _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" # Runtime flow. +_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" _RUNTIME_TOKEN_SECRET_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_SECRET" _RUNTIME_TOKEN_TTL_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS" _DEFAULT_RUNTIME_TOKEN_TTL_SECONDS = 300 @@ -80,15 +86,19 @@ def configure_auth_from_env() -> None: Default flow: - - ``AGENT_CONTROL_AUTH_MODE=header`` (default): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. Runtime flow: - - When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, register - :class:`LocalJwtVerifyProvider` as an override for - :data:`Operation.RUNTIME_USE`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime + token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token + secret is configured): :class:`LocalJwtVerifyProvider`. 
Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -101,27 +111,27 @@ def configure_auth_from_env() -> None: global _runtime_auth_config clear_authorizers() _active_providers.clear() - _runtime_auth_config = _load_runtime_auth_config() + runtime_mode = _resolve_runtime_mode() + _runtime_auth_config = ( + _load_runtime_auth_config(require_secret=True) if runtime_mode == "jwt" else None + ) default = _build_default_provider() set_authorizer(default) _active_providers.append(default) - if _runtime_auth_config is not None: - runtime_provider = LocalJwtVerifyProvider(secret=_runtime_auth_config.secret) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": _logger.info( - "Runtime auth enabled: LocalJwtVerifyProvider override installed for %s", + "Runtime auth provider: jwt override installed for %s", Operation.RUNTIME_USE.value, ) else: - _logger.warning( - "Runtime auth disabled (%s not set); %s falls through to the " - "default authorizer, which may grant any authenticated credential. 
" - "Set the runtime token secret to bind runtime calls to a " - "short-lived target-scoped JWT.", - _RUNTIME_TOKEN_SECRET_ENV, + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, Operation.RUNTIME_USE.value, ) @@ -172,9 +182,12 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "header").strip().lower() - if mode == "header": - _logger.info("Default auth provider: header (local credentials)") + mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + if mode in {"none", "no_auth"}: + _logger.info("Default auth provider: none") + return NoAuthProvider() + if mode in {"api_key", "header"}: + _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": url = os.environ.get(_UPSTREAM_URL_ENV) @@ -192,19 +205,60 @@ def _build_default_provider() -> RequestAuthorizer: service_token_header=token_header, ) ) - raise RuntimeError(f"Unknown {_MODE_ENV}={mode!r}; expected 'header' or 'http_upstream'.") + raise RuntimeError( + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + ) + + +def _resolve_runtime_mode() -> str: + raw = os.environ.get(_RUNTIME_MODE_ENV) + if raw is None or not raw.strip(): + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "jwt": + return mode + raise RuntimeError( + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." 
+ ) + + +def _build_runtime_provider( + mode: str, + config: RuntimeAuthConfig | None, +) -> RequestAuthorizer: + if mode == "none": + return NoAuthProvider() + if mode == "api_key": + return HeaderAuthProvider() + if mode == "jwt": + if config is None: + raise RuntimeError(f"{_RUNTIME_MODE_ENV}=jwt but runtime auth config is missing.") + return LocalJwtVerifyProvider(secret=config.secret) + raise RuntimeError( + f"Unknown runtime auth mode {mode!r}; expected 'none', 'api_key', or 'jwt'." + ) -def _load_runtime_auth_config() -> RuntimeAuthConfig | None: +def _load_runtime_auth_config(*, require_secret: bool = False) -> RuntimeAuthConfig | None: """Parse, validate, and return the runtime-auth config from env. - Returns ``None`` when no runtime secret is configured. Raises - ``RuntimeError`` when the secret is too short or the TTL is invalid - so misconfiguration surfaces at startup, not on the first - request-time mint. + Returns ``None`` when no runtime secret is configured and + ``require_secret`` is false. Raises ``RuntimeError`` when the + secret is required, too short, or the TTL is invalid so + misconfiguration surfaces at startup, not on the first request-time + mint. """ secret = os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) if not secret: + if require_secret: + raise RuntimeError( + f"{_RUNTIME_MODE_ENV}=jwt requires {_RUNTIME_TOKEN_SECRET_ENV} to be set." + ) return None if len(secret.encode("utf-8")) < _RUNTIME_TOKEN_SECRET_MIN_BYTES: raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 9299b441..e0ea6da7 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -42,14 +42,21 @@ class Operation(StrEnum): CONTROL_BINDINGS_READ = "control_bindings.read" CONTROL_BINDINGS_WRITE = "control_bindings.write" - # Runtime token exchange — wired on the exchange endpoint. 
+ # Runtime token exchange - wired on the exchange endpoint. RUNTIME_TOKEN_EXCHANGE = "runtime.token_exchange" - # Reserved for follow-up migrations; not yet wired on endpoints. CONTROLS_READ = "controls.read" CONTROLS_CREATE = "controls.create" CONTROLS_UPDATE = "controls.update" CONTROLS_DELETE = "controls.delete" + POLICIES_READ = "policies.read" + POLICIES_CREATE = "policies.create" + POLICIES_UPDATE = "policies.update" + POLICIES_DELETE = "policies.delete" + AGENTS_READ = "agents.read" + AGENTS_CREATE = "agents.create" + AGENTS_UPDATE = "agents.update" + AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" @@ -61,8 +68,7 @@ class Principal: namespace_key: The namespace the request runs in. Endpoints use this to scope every read and write. is_admin: Whether the caller has admin privileges in the - current namespace. Mostly informational for endpoints that - still gate on the legacy admin-key contract. + current namespace. caller_id: Opaque, provider-supplied identifier for the caller (e.g., a key fingerprint or user id). Useful for audit logging; never echo back to clients. @@ -122,7 +128,7 @@ def set_authorizer( Without ``operation``, this becomes the default authorizer used by every operation that does not have a specific override. With - ``operation``, it overrides the default for that operation only — + ``operation``, it overrides the default for that operation only - used to route a different family (e.g., runtime) through a different provider. 
diff --git a/server/src/agent_control_server/auth_framework/providers/__init__.py b/server/src/agent_control_server/auth_framework/providers/__init__.py index e8a68486..ad5d6b38 100644 --- a/server/src/agent_control_server/auth_framework/providers/__init__.py +++ b/server/src/agent_control_server/auth_framework/providers/__init__.py @@ -3,10 +3,12 @@ from .header import AccessLevel, HeaderAuthProvider from .http_upstream import HttpUpstreamAuthProvider from .local_jwt import LocalJwtVerifyProvider +from .no_auth import NoAuthProvider __all__ = [ "AccessLevel", "HeaderAuthProvider", "HttpUpstreamAuthProvider", "LocalJwtVerifyProvider", + "NoAuthProvider", ] diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index f76936a1..228ec443 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -1,23 +1,14 @@ """Default :class:`RequestAuthorizer` that uses local credentials only. -Resolves the namespace from a header (or falls back to -``DEFAULT_NAMESPACE_KEY``) and enforces a per-operation access level -using the legacy API-key + session-cookie credential check from -:mod:`agent_control_server.auth`. Behavior matches the pre-framework -local auth path verbatim: +Returns ``DEFAULT_NAMESPACE_KEY`` and enforces a per-operation access +level using the local API-key + session-cookie credential check from +:mod:`agent_control_server.auth`: - ``ADMIN`` operations require an admin key (or admin session). - ``AUTHENTICATED`` operations require any valid credential. - ``PUBLIC`` operations are open. -- When ``api_key_enabled`` is ``False`` (no-auth mode), every - operation succeeds with a non-admin :class:`Principal` — preserved - by the underlying credential check. 
- -The header lookup is wired but currently inert: the provider always -returns the default namespace because non-binding write endpoints -still hardcode it. The header is kept here so a follow-up that -threads namespace resolution through the rest of the API can flip it -on without changing the provider contract. +- When the underlying local credential layer is disabled, every + operation succeeds with a non-admin :class:`Principal`. """ from __future__ import annotations @@ -51,6 +42,14 @@ class AccessLevel(Enum): Operation.CONTROLS_CREATE: AccessLevel.ADMIN, Operation.CONTROLS_UPDATE: AccessLevel.ADMIN, Operation.CONTROLS_DELETE: AccessLevel.ADMIN, + Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, + Operation.POLICIES_CREATE: AccessLevel.ADMIN, + Operation.POLICIES_UPDATE: AccessLevel.ADMIN, + Operation.POLICIES_DELETE: AccessLevel.ADMIN, + Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, + Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, + Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } @@ -60,7 +59,7 @@ class HeaderAuthProvider(RequestAuthorizer): """Default authorizer. For each operation's configured access level, validates the - request's credentials via the legacy local check; on success, + request's credentials via the local credential check; on success, returns a :class:`Principal` scoped to the resolved namespace. """ @@ -100,8 +99,7 @@ async def authorize( ) # Runtime token exchange returns a normalized scope grant so the # exchange endpoint can require ``runtime.use`` uniformly across - # providers; an upstream that explicitly grants no scopes ends - # up with an empty tuple and is rejected. + # providers. scopes: tuple[str, ...] 
= ( (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () ) @@ -113,10 +111,7 @@ async def authorize( ) def _resolve_namespace_key(self, request: Request) -> str: - # The provider always returns the default namespace because - # non-binding write endpoints still hardcode it; serving - # anything else here would create rows the rest of the API - # cannot find. The branch is preserved so a future change can - # lift the lock without touching the provider contract. + # Local credentials do not carry namespace metadata. Providers + # that resolve a namespace can return a different principal. del request return self._default_namespace_key diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index a97a3de8..8d5c850c 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -67,8 +67,8 @@ class _UpstreamGrant(BaseModel): """Strict schema for the upstream authorization-service response. Unknown fields are tolerated (so the upstream can evolve), but every - *known* field is type-checked. A wrong type on any field — or a - half-supplied target binding — causes the provider to fail closed + *known* field is type-checked. A wrong type on any field - or a + half-supplied target binding - causes the provider to fail closed with a 502. """ @@ -108,7 +108,7 @@ def _target_must_be_paired(self) -> _UpstreamGrant: A target is meaningful only as a ``(target_type, target_id)`` pair; allowing one side without the other would let a malformed grant pass and the exchange endpoint mint a token for the - request's value of the missing half — outside the upstream's + request's value of the missing half - outside the upstream's intended authorization. 
""" if (self.target_type is None) != (self.target_id is None): diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index bb448503..8620d3b6 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -6,7 +6,7 @@ returns a :class:`Principal` carrying the bound target. When a ``context_builder`` on the dependency surfaces ``target_type`` / ``target_id``, the provider also enforces that they match the token's -binding — runtime endpoints get the request-target check for free. +binding - runtime endpoints get the request-target check for free. """ from __future__ import annotations diff --git a/server/src/agent_control_server/auth_framework/providers/no_auth.py b/server/src/agent_control_server/auth_framework/providers/no_auth.py new file mode 100644 index 00000000..509ca4f3 --- /dev/null +++ b/server/src/agent_control_server/auth_framework/providers/no_auth.py @@ -0,0 +1,29 @@ +"""Authorizer for deployments that intentionally disable authentication.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Request + +from ...models import DEFAULT_NAMESPACE_KEY +from ..core import Operation, Principal, RequestAuthorizer + + +class NoAuthProvider(RequestAuthorizer): + """Allows every operation and returns the default namespace.""" + + def __init__(self, *, default_namespace_key: str = DEFAULT_NAMESPACE_KEY) -> None: + self._default_namespace_key = default_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + scopes: tuple[str, ...] 
= ( + (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () + ) + return Principal(namespace_key=self._default_namespace_key, scopes=scopes) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 034ae35f..ac099911 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey, require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -53,7 +53,6 @@ Policy, agent_policies, ) -from ..namespace import get_namespace_key from ..services.agent_names import normalize_agent_name_or_422 from ..services.controls import ( AgentControlEnabledState, @@ -112,7 +111,7 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - # Skip unrendered template controls — they have no evaluators to validate. + # Skip unrendered template controls - they have no evaluators to validate. if ( isinstance(control.data, dict) and control.data.get("template") is not None @@ -286,7 +285,7 @@ async def list_agents( limit: int = _DEFAULT_PAGINATION_LIMIT, name: str | None = None, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListAgentsResponse: """ List all registered agents with cursor-based pagination. 
@@ -300,11 +299,13 @@ async def list_agents( limit: Pagination limit (default 20, max 100) name: Optional name filter (case-insensitive partial match) db: Database session (injected) - namespace_key: Resolved namespace for the request + principal: Authorized request principal Returns: ListAgentsResponse with agent summaries and pagination info """ + namespace_key = principal.namespace_key + # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -377,14 +378,20 @@ async def list_agents( agent_policies.c.agent_name, agent_policies.c.policy_id, ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) .order_by(agent_policies.c.agent_name, agent_policies.c.policy_id) ) policy_ids_result = await db.execute(policy_ids_query) for assoc_agent_name, policy_id in policy_ids_result.all(): policy_ids_map.setdefault(assoc_agent_name, []).append(policy_id) - control_counts_map = await control_service.list_active_control_counts_by_agent(agent_names) + control_counts_map = await control_service.list_active_control_counts_by_agent( + agent_names, + namespace_key=namespace_key, + ) # Build summaries summaries: list[AgentSummary] = [] @@ -436,9 +443,8 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. 
@@ -462,10 +468,13 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) + principal: Authorized request principal Returns: InitAgentResponse with created flag and the effective controls """ + namespace_key = principal.namespace_key + # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() for ev in request.evaluators: @@ -835,7 +844,7 @@ async def init_agent( async def get_agent( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentResponse: """ Retrieve agent metadata and all registered steps. @@ -845,8 +854,7 @@ async def get_agent( Args: agent_name: Agent identifier db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: GetAgentResponse with agent metadata and step list @@ -855,6 +863,7 @@ async def get_agent( HTTPException 404: Agent not found HTTPException 422: Agent data is corrupted """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -917,7 +926,7 @@ async def _get_agent_or_404( The lookup is always namespace-scoped: an agent that exists only in another namespace surfaces as 404 (non-disclosing) so duplicate - names across namespaces — which the schema explicitly permits — + names across namespaces - which the schema explicitly permits - cannot be addressed across the namespace boundary. 
""" normalized_agent_name = normalize_agent_name_or_422(agent_name) @@ -940,7 +949,6 @@ async def _get_agent_or_404( @router.post( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate policy with agent", response_description="Success confirmation", @@ -949,9 +957,10 @@ async def add_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a policy with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1017,7 +1026,6 @@ async def add_agent_policy( @router.post( "/{agent_name}/policy/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=SetPolicyResponse, summary="Assign policy to agent (compatibility)", response_description="Success status with previous policy ID", @@ -1026,9 +1034,10 @@ async def set_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> SetPolicyResponse: """Compatibility endpoint that replaces all policy associations with one policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1117,9 +1126,10 @@ async def set_agent_policy( async def get_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentPoliciesResponse: """List policy IDs associated with an agent.""" + namespace_key = 
principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) result = await db.execute( select(agent_policies.c.policy_id) @@ -1141,9 +1151,10 @@ async def get_agent_policies( async def get_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetPolicyResponse: """Compatibility endpoint that returns the first associated policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( select(Policy.id) @@ -1172,7 +1183,6 @@ async def get_agent_policy( @router.delete( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove policy association from agent", response_description="Success confirmation", @@ -1181,13 +1191,14 @@ async def remove_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove a policy association from an agent. Idempotent for existing resources: removing a non-associated link is a no-op. Missing agent/policy resources still return 404. 
""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1230,7 +1241,6 @@ async def remove_agent_policy( @router.delete( "/{agent_name}/policies", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove all policy associations from agent", response_description="Success confirmation", @@ -1238,9 +1248,10 @@ async def remove_agent_policy( async def remove_all_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove all policy associations from an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) try: @@ -1271,7 +1282,6 @@ async def remove_all_agent_policies( @router.delete( "/{agent_name}/policy", - dependencies=[Depends(require_admin_key)], response_model=DeletePolicyResponse, summary="Remove agent's policy assignment (compatibility)", response_description="Success confirmation", @@ -1279,9 +1289,10 @@ async def remove_all_agent_policies( async def delete_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> DeletePolicyResponse: """Compatibility endpoint that removes all policy associations.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) existing_policy_result = await db.execute( @@ -1328,7 +1339,6 @@ async def delete_agent_policy( @router.post( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate control directly with agent", response_description="Success confirmation", @@ 
-1337,9 +1347,10 @@ async def add_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a control directly with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) control = await control_service.get_active_control_or_404( @@ -1389,7 +1400,6 @@ async def add_agent_control( @router.delete( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=RemoveAgentControlResponse, summary="Remove direct control association from agent", response_description="Success confirmation", @@ -1398,9 +1408,10 @@ async def remove_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> RemoveAgentControlResponse: """Remove a direct control association from an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) await control_service.get_active_control_or_404(control_id, namespace_key=namespace_key) @@ -1481,7 +1492,7 @@ async def list_agent_controls( description="Optional opaque target identifier. Required when target_type is supplied.", ), db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> AgentControlsResponse: """ List protection controls effective for an agent. 
@@ -1506,7 +1517,7 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - namespace_key: Namespace scoping for the resolution (injected) + principal: Authorized request principal Returns: AgentControlsResponse with controls matching the requested state filters @@ -1515,6 +1526,8 @@ async def list_agent_controls( HTTPException 400: target_type and target_id were not supplied together HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key + if (target_type is None) != (target_id is None): raise BadRequestError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1572,7 +1585,7 @@ async def list_agent_evaluators( cursor: str | None = None, limit: int = _DEFAULT_PAGINATION_LIMIT, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListEvaluatorsResponse: """ List all evaluator schemas registered with an agent. @@ -1586,8 +1599,7 @@ async def list_agent_evaluators( cursor: Optional cursor for pagination (name of last evaluator from previous page) limit: Pagination limit (default 20, max 100) db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). 
+ principal: Authorized request principal Returns: ListEvaluatorsResponse with evaluator schemas and pagination @@ -1595,6 +1607,7 @@ async def list_agent_evaluators( Raises: HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -1672,7 +1685,7 @@ async def get_agent_evaluator( agent_name: str, evaluator_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> EvaluatorSchemaItem: """ Get a specific evaluator schema registered with an agent. @@ -1681,8 +1694,7 @@ async def get_agent_evaluator( agent_name: Agent identifier evaluator_name: Name of the evaluator db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: EvaluatorSchemaItem with schema details @@ -1690,6 +1702,7 @@ async def get_agent_evaluator( Raises: HTTPException 404: Agent or evaluator not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -1734,7 +1747,6 @@ async def get_agent_evaluator( @router.patch( "/{agent_name}", - dependencies=[Depends(require_admin_key)], response_model=PatchAgentResponse, summary="Modify agent (remove steps/evaluators)", response_description="Lists of removed items", @@ -1743,7 +1755,7 @@ async def patch_agent( agent_name: str, request: PatchAgentRequest, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> PatchAgentResponse: """ Remove steps and/or evaluators from an agent. 
@@ -1755,6 +1767,7 @@ async def patch_agent( agent_name: Agent identifier request: Lists of step/evaluator identifiers to remove db: Database session (injected) + principal: Authorized request principal Returns: PatchAgentResponse with lists of actually removed items @@ -1763,6 +1776,7 @@ async def patch_agent( HTTPException 404: Agent not found HTTPException 500: Database error during update """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 1a23baa8..f80cd2fa 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,9 +2,8 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer (typically -:class:`HttpUpstreamAuthProvider` in production) authenticates the -credential and authorizes the implied +target_id)``; the default authorizer authenticates the credential and +authorizes the implied :data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the @@ -130,8 +129,8 @@ async def runtime_token_exchange( actor_id = principal.caller_id or "anonymous" # The exchange endpoint requires the authorizer to explicitly grant - # runtime.use. Providers that do not surface scopes (legacy local - # provider) supply a normalized grant for ``RUNTIME_TOKEN_EXCHANGE``; + # runtime.use. Local providers supply a normalized grant for + # ``RUNTIME_TOKEN_EXCHANGE``; # upstream providers that return an explicit empty scopes array fail # closed here rather than escalating to runtime.use. 
if Operation.RUNTIME_USE.value not in principal.scopes: @@ -155,7 +154,7 @@ async def runtime_token_exchange( ) except UpstreamGrantExpiredError as exc: # Upstream returned a grant whose ``expires_at`` is already in - # the past — minting would hand the caller a token that's dead + # the past - minting would hand the caller a token that's dead # on arrival. Distinguished from the misconfigured case so the # error code and status reflect "upstream returned bad data." raise APIError( diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 92798ae1..d2fe4b44 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -26,7 +26,6 @@ from ..db import get_async_db from ..errors import BadRequestError from ..models import ControlBinding -from ..namespace import get_namespace_key from ..services.control_bindings import ControlBindingsService router = APIRouter(prefix="/control-bindings", tags=["control-bindings"]) @@ -94,26 +93,21 @@ def _to_response(binding: ControlBinding) -> GetControlBindingResponse: async def create_control_binding( request: CreateControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. Each binding row is scoped to the request namespace as resolved by - ``get_namespace_key``. 
The auth chain still runs via - ``require_operation`` for authentication and authorization, but the - storage namespace is taken from the same resolver the rest of the - server uses so binding writes and runtime reads stay in lockstep - until auth-derived namespace resolution lands across every endpoint. + the active authorizer. """ service = ControlBindingsService(db) binding = await service.create_binding( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -148,20 +142,18 @@ async def list_control_bindings( target_id: str | None = None, control_id: int | None = None, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_READ, context_builder=_binding_list_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> ListControlBindingsResponse: """Return bindings in the request namespace with optional filters and cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by ``get_namespace_key`` so this - listing stays in lockstep with the rest of the server's reads. + storage namespace is resolved by the active authorizer. 
""" parsed_cursor: int | None if cursor is None: @@ -177,7 +169,7 @@ async def list_control_bindings( ) from exc service = ControlBindingsService(db) page = await service.list_bindings( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, cursor=parsed_cursor, limit=limit, target_type=target_type, @@ -204,8 +196,7 @@ async def list_control_bindings( async def get_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), ) -> GetControlBindingResponse: """Read a single control binding by surrogate ID. @@ -218,7 +209,9 @@ async def get_control_binding( of which forward ``(target_type, target_id)`` to the authorizer. """ service = ControlBindingsService(db) - binding = await service.get_binding_or_404(namespace_key=namespace_key, binding_id=binding_id) + binding = await service.get_binding_or_404( + namespace_key=principal.namespace_key, binding_id=binding_id + ) return _to_response(binding) @@ -232,8 +225,7 @@ async def patch_control_binding( binding_id: int, request: PatchControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> PatchControlBindingResponse: """Update the ``enabled`` flag on a control binding. 
@@ -244,7 +236,7 @@ async def patch_control_binding( """ service = ControlBindingsService(db) binding = await service.set_enabled( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, binding_id=binding_id, enabled=request.enabled, ) @@ -261,8 +253,7 @@ async def patch_control_binding( async def delete_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> DeleteControlBindingResponse: """Delete a control binding by surrogate ID. @@ -272,7 +263,7 @@ async def delete_control_binding( target-scoped detach that forwards the target to the authorizer. """ service = ControlBindingsService(db) - await service.delete_binding(namespace_key=namespace_key, binding_id=binding_id) + await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) await db.commit() return DeleteControlBindingResponse(success=True) @@ -286,13 +277,12 @@ async def delete_control_binding( async def upsert_control_binding_by_key( request: UpsertControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> UpsertControlBindingResponse: """Idempotent attach using ``(target_type, target_id, control_id)`` as the natural key. 
Updates ``enabled`` on an existing match; creates a new row @@ -300,7 +290,7 @@ async def upsert_control_binding_by_key( """ service = ControlBindingsService(db) binding, created = await service.upsert_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -324,20 +314,19 @@ async def upsert_control_binding_by_key( async def delete_control_binding_by_key( request: DeleteControlBindingByKeyRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> DeleteControlBindingByKeyResponse: """Idempotent detach by natural key. Returns ``deleted=False`` when no matching binding exists. """ service = ControlBindingsService(db) deleted = await service.delete_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index fcb7cb18..5b01593c 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -229,7 +229,7 @@ async def _materialize_control_input( enabled=enabled, ) - # Incomplete values — only allowed for new controls or already-unrendered + # Incomplete values - only allowed for new controls or already-unrendered # templates. Updating a rendered control with incomplete values is # rejected to prevent silently stripping rendered fields. 
current_is_rendered = ( @@ -470,7 +470,7 @@ async def render_control_template( async def create_control( request: CreateControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -492,7 +492,10 @@ async def create_control( control_service = ControlService(db) # Uniqueness check - if await control_service.active_control_name_exists(request.name): + namespace_key = principal.namespace_key + if await control_service.active_control_name_exists( + request.name, namespace_key=namespace_key + ): raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, detail=f"Control with name '{request.name}' already exists", @@ -504,7 +507,11 @@ async def create_control( control_def = await _materialize_control_input(request.data, db=db) control_data = _serialize_control_data(control_def) - control = control_service.create_control(name=request.name, data=control_data) + control = control_service.create_control( + namespace_key=namespace_key, + name=request.name, + data=control_data, + ) try: await control_service.create_version( control, @@ -569,7 +576,7 @@ async def get_control_schema() -> GetControlSchemaResponse: async def get_control( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. 
@@ -584,7 +591,9 @@ async def get_control( Raises: HTTPException 404: Control not found """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -608,7 +617,7 @@ async def get_control( async def get_control_data( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -626,7 +635,9 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -648,10 +659,15 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" - page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) + page = await ControlService(db).list_versions( + control_id, + namespace_key=principal.namespace_key, + cursor=cursor, + limit=limit, + ) return ListControlVersionsResponse( versions=[ @@ -682,10 +698,12 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = 
Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" - version = await ControlService(db).get_version_or_404(control_id, version_num) + version = await ControlService(db).get_version_or_404( + control_id, version_num, namespace_key=principal.namespace_key + ) return GetControlVersionResponse( version_num=version.version_num, event_type=version.event_type, @@ -705,7 +723,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -726,7 +744,9 @@ async def set_control_data( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=principal.namespace_key, for_update=True + ) control_def = await _materialize_control_input( request.data, @@ -767,11 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Validation uses the authoring path, so require create access. +# Authorized as CONTROLS_READ: validate exercises the materialization +# path but does not mutate stored control data. 
async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -811,7 +832,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -837,7 +858,9 @@ async def list_controls( GET /controls?limit=10&enabled=true&step_type=tool """ control_service = ControlService(db) + namespace_key = principal.namespace_key page = await control_service.list_controls_page( + namespace_key=namespace_key, cursor=cursor, limit=limit, name=name, @@ -849,7 +872,8 @@ async def list_controls( tag=tag, ) usage_by_control_id = await control_service.list_control_usage( - [control.id for control in page.controls] + [control.id for control in page.controls], + namespace_key=namespace_key, ) # Build summaries (filtering already done at DB level) @@ -910,7 +934,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. 
@@ -933,13 +957,18 @@ async def delete_control( """ control_service = ControlService(db) bindings_service = ControlBindingsService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) - associations = await control_service.list_control_associations(control_id) + associations = await control_service.list_control_associations( + control_id, namespace_key=namespace_key + ) associated_policy_ids = associations.policy_ids associated_agent_names = associations.agent_names target_binding_ids = await bindings_service.list_binding_ids_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if ( @@ -996,13 +1025,15 @@ async def delete_control( dissociated_from_policies: list[int] = [] dissociated_from_agents: list[str] = [] if associated_policy_ids or associated_agent_names: - dissociated = await control_service.remove_all_control_associations(control_id) + dissociated = await control_service.remove_all_control_associations( + control_id, namespace_key=namespace_key + ) dissociated_from_policies = dissociated.policy_ids dissociated_from_agents = dissociated.agent_names detached_target_bindings: list[int] = [] if target_binding_ids: detached_target_bindings = await bindings_service.delete_bindings_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if dissociated_from_policies or dissociated_from_agents or detached_target_bindings: _logger.info( @@ -1057,7 +1088,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = 
Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). @@ -1081,7 +1112,10 @@ async def patch_control( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) parsed_control = _parse_stored_control_data( control.data, control_name=control.name, @@ -1096,6 +1130,7 @@ async def patch_control( # Check for name collision if await control_service.active_control_name_exists( request.name, + namespace_key=namespace_key, exclude_control_id=control_id, ): raise ConflictError( diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index e018796e..437af8b5 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -10,16 +10,15 @@ EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..namespace import get_namespace_key from ..services.controls import ControlService router = APIRouter(prefix="/evaluation", tags=["evaluation"]) @@ -118,6 +117,20 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) +async def _evaluation_context(request: Request) -> dict[str, object]: + 
"""Surface target identifiers to the runtime authorizer.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return {} + if not isinstance(body, dict): + return {} + return { + "target_type": body.get("target_type"), + "target_id": body.get("target_id"), + } + + @router.post( "", response_model=EvaluationResponse, @@ -126,9 +139,10 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) async def evaluate( request: EvaluationRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends( + require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) + ), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -144,7 +158,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - del client # Authentication is still required by dependency injection. 
+ namespace_key = principal.namespace_key agent_result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..ddda7127 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -9,7 +9,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ConflictError, DatabaseError, NotFoundError from ..logging_utils import get_logger @@ -23,13 +23,14 @@ @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreatePolicyResponse, summary="Create a new policy", response_description="Created policy ID", ) async def create_policy( - request: CreatePolicyRequest, db: AsyncSession = Depends(get_async_db) + request: CreatePolicyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_CREATE)), ) -> CreatePolicyResponse: """ Create a new empty policy with a unique name. 
@@ -48,8 +49,14 @@ async def create_policy( HTTPException 409: Policy with this name already exists HTTPException 500: Database error during creation """ + namespace_key = principal.namespace_key # Uniqueness check - existing = await db.execute(select(Policy.id).where(Policy.name == request.name)) + existing = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.name == request.name, + ) + ) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.POLICY_NAME_CONFLICT, @@ -59,7 +66,7 @@ async def create_policy( hint="Choose a different name or update the existing policy.", ) - policy = Policy(name=request.name) + policy = Policy(namespace_key=namespace_key, name=request.name) db.add(policy) try: await db.commit() @@ -80,13 +87,15 @@ async def create_policy( @router.post( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Add control to policy", response_description="Success confirmation", ) async def add_control_to_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Associate a control with a policy. 
@@ -106,8 +115,14 @@ async def add_control_to_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ + namespace_key = principal.namespace_key # Find policy and control - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -119,11 +134,17 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: - await control_service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await control_service.add_control_to_policy( + policy_id=policy_id, + control_id=control_id, + namespace_key=namespace_key, + ) await db.commit() except Exception: await db.rollback() @@ -149,13 +170,15 @@ async def add_control_to_policy( @router.delete( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove control from policy", response_description="Success confirmation", ) async def remove_control_from_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Remove a control from a policy. 
@@ -175,7 +198,13 @@ async def remove_control_from_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -187,13 +216,16 @@ async def remove_control_from_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Remove association (idempotent - deleting non-existent is no-op) try: await control_service.remove_control_from_policy( policy_id=policy_id, control_id=control_id, + namespace_key=namespace_key, ) await db.commit() except Exception: @@ -222,7 +254,9 @@ async def remove_control_from_policy( response_description="List of control IDs", ) async def list_policy_controls( - policy_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_READ)), ) -> GetPolicyControlsResponse: """ List all controls associated with a policy. 
@@ -237,7 +271,13 @@ async def list_policy_controls( Raises: HTTPException 404: Policy not found """ - pol_res = await db.execute(select(Policy.id).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) if pol_res.first() is None: raise NotFoundError( error_code=ErrorCode.POLICY_NOT_FOUND, @@ -247,5 +287,8 @@ async def list_policy_controls( hint="Verify the policy ID is correct and the policy has been created.", ) - control_ids = await ControlService(db).list_policy_control_ids(policy_id) + control_ids = await ControlService(db).list_policy_control_ids( + policy_id, + namespace_key=namespace_key, + ) return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index c386cf22..ddd22195 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -252,7 +252,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # Register handler for FastAPI's RequestValidationError (Pydantic validation) app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] -# Register handler for standard HTTPException (legacy code, FastAPI internals) +# Register handler for standard HTTPException (older routes, FastAPI internals) app.add_exception_handler(HTTPException, http_exception_handler) # type: ignore[arg-type] # Register catch-all handler for unexpected exceptions @@ -261,16 +261,18 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # API v1 prefix for all routes api_v1_prefix = f"{settings.api_prefix}/{settings.api_version}" -# Protected routes (require valid API key) +# API routers. 
Routers migrated to the auth framework mount the +# non-validating header extractor only so OpenAPI advertises X-API-Key; +# each endpoint's ``require_operation`` dependency owns authn + authz. app.include_router( agent_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( policy_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # Endpoint dependencies handle auth; this advertises X-API-Key. @@ -281,11 +283,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( # The auth framework on each endpoint owns authentication and # authorization for control bindings, so this router is mounted - # without the legacy router-level gate. See ``auth_framework`` for + # without the router-level auth gate. See ``auth_framework`` for # the provider contract. ``get_api_key_from_header`` is a non- # validating extractor (``auto_error=False``); it is attached purely # so the generated OpenAPI spec advertises the X-API-Key requirement - # on these routes — without it, downstream SDK generators would treat + # on these routes - without it, downstream SDK generators would treat # the routes as unauthenticated. control_binding_router, prefix=api_v1_prefix, @@ -309,9 +311,10 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( evaluation_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) +# Evaluator discovery still uses the local credential dependency. 
app.include_router( evaluator_router, prefix=api_v1_prefix, @@ -324,7 +327,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, ) -# System routes (config, login, logout) — no auth required +# System routes (config, login, logout) - no auth required app.include_router( system_router, prefix=settings.api_prefix, diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 263120b7..41a62282 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,6 +20,7 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( + DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -96,9 +97,15 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: + def create_control( + self, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + name: str, + data: dict[str, Any], + ) -> Control: """Create a new pending control row.""" - control = Control(name=name, data=data) + control = Control(namespace_key=namespace_key, name=name, data=data) self._db.add(control) return control @@ -128,10 +135,13 @@ async def get_control_or_404( self, control_id: int, *, + namespace_key: str | None = None, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" stmt = select(Control).where(Control.id == control_id) + if namespace_key is not None: + stmt = stmt.where(Control.namespace_key == namespace_key) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -180,10 +190,15 @@ async def active_control_name_exists( self, name: str, *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" - stmt 
= select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + stmt = select(Control.id).where( + Control.namespace_key == namespace_key, + Control.name == name, + Control.deleted_at.is_(None), + ) if exclude_control_id is not None: stmt = stmt.where(Control.id != exclude_control_id) result = await self._db.execute(stmt) @@ -216,11 +231,12 @@ async def list_versions( self, control_id: int, *, + namespace_key: str, cursor: int | None, limit: int, ) -> ControlVersionPage: """Return control versions newest-first with cursor pagination.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) total_result = await self._db.execute( select(func.count()) @@ -255,9 +271,11 @@ async def list_versions( next_cursor=next_cursor, ) - async def get_version_or_404(self, control_id: int, version_num: int) -> ControlVersion: + async def get_version_or_404( + self, control_id: int, version_num: int, *, namespace_key: str + ) -> ControlVersion: """Load a specific version row for a control.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) result = await self._db.execute( select(ControlVersion).where( @@ -303,12 +321,17 @@ async def list_controls_for_policy( result = await self._db.execute(stmt) return list(result.scalars().unique().all()) - async def list_policy_control_ids(self, policy_id: int) -> list[int]: + async def list_policy_control_ids(self, policy_id: int, *, namespace_key: str) -> list[int]: """Return active control IDs directly associated with a policy.""" result = await self._db.execute( select(policy_controls.c.control_id) .join(Control, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) + .where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.policy_id == policy_id, + Control.namespace_key == namespace_key, + 
Control.deleted_at.is_(None), + ) .order_by(policy_controls.c.control_id) ) return [cast(int, row[0]) for row in result.all()] @@ -396,6 +419,7 @@ async def list_runtime_controls_for_agent( async def list_controls_page( self, *, + namespace_key: str, cursor: int | None, limit: int, name: str | None, @@ -407,7 +431,11 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = ( + select(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + .order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -424,7 +452,11 @@ async def list_controls_page( result = await self._db.execute(query.limit(limit + 1)) controls = list(result.scalars().all()) - total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) + total_query = ( + select(func.count()) + .select_from(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + ) total_query = self._apply_control_list_filters( total_query, name=name, @@ -453,7 +485,9 @@ async def list_controls_page( next_cursor=next_cursor, ) - async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: + async def list_control_usage( + self, control_ids: Sequence[int], *, namespace_key: str + ) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: return {} @@ -465,8 +499,16 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_policies.c.agent_name, ) .select_from(policy_controls) - .join(agent_policies, policy_controls.c.policy_id == agent_policies.c.policy_id) - .where(policy_controls.c.control_id.in_(control_ids)) + .join( + agent_policies, + (policy_controls.c.policy_id == 
agent_policies.c.policy_id) + & (policy_controls.c.namespace_key == agent_policies.c.namespace_key), + ) + .where( + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(control_ids), + ) ) direct_agents_query = ( select( @@ -474,7 +516,10 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_controls.c.agent_name, ) .select_from(agent_controls) - .where(agent_controls.c.control_id.in_(control_ids)) + .where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(control_ids), + ) ) agents_result = await self._db.execute(union_all(policy_agents_query, direct_agents_query)) for control_id, agent_name in agents_result.all(): @@ -491,6 +536,8 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: @@ -503,15 +550,24 @@ async def list_active_control_counts_by_agent( ) .select_from( agent_policies.join( - policy_controls, agent_policies.c.policy_id == policy_controls.c.policy_id + policy_controls, + (agent_policies.c.policy_id == policy_controls.c.policy_id) + & (agent_policies.c.namespace_key == policy_controls.c.namespace_key), ) ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) ) direct_associations = select( agent_controls.c.agent_name.label("agent_name"), agent_controls.c.control_id.label("control_id"), - ).where(agent_controls.c.agent_name.in_(agent_names)) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.agent_name.in_(agent_names), + ) all_associations = 
union_all(policy_associations, direct_associations).subquery() result = await self._db.execute( @@ -521,6 +577,7 @@ async def list_active_control_counts_by_agent( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.namespace_key == namespace_key, Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", @@ -531,19 +588,28 @@ async def list_active_control_counts_by_agent( ) return {cast(str, row[0]): cast(int, row[1]) for row in result.all()} - async def add_control_to_policy(self, *, policy_id: int, control_id: int) -> None: + async def add_control_to_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Create a policy-control association if it does not already exist.""" await self._db.execute( pg_insert(policy_controls) - .values(policy_id=policy_id, control_id=control_id) + .values( + namespace_key=namespace_key, + policy_id=policy_id, + control_id=control_id, + ) .on_conflict_do_nothing() ) - async def remove_control_from_policy(self, *, policy_id: int, control_id: int) -> None: + async def remove_control_from_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Remove a policy-control association if it exists.""" await self._db.execute( delete(policy_controls).where( - (policy_controls.c.policy_id == policy_id) + (policy_controls.c.namespace_key == namespace_key) + & (policy_controls.c.policy_id == policy_id) & (policy_controls.c.control_id == control_id) ) ) @@ -613,16 +679,24 @@ async def remove_control_from_agent( control_still_active=policy_inheritance_result.first() is not None, ) - async def list_control_associations(self, control_id: int) -> ControlAssociations: + async def list_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Return all policy and direct agent associations for a control.""" policy_assoc_query = select( policy_controls.c.policy_id.label("policy_id"), literal(None, 
type_=String).label("agent_name"), - ).where(policy_controls.c.control_id == control_id) + ).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) agent_assoc_query = select( literal(None, type_=Integer).label("policy_id"), agent_controls.c.agent_name.label("agent_name"), - ).where(agent_controls.c.control_id == control_id) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) assoc_result = await self._db.execute(union_all(policy_assoc_query, agent_assoc_query)) policy_ids: set[int] = set() @@ -638,16 +712,26 @@ async def list_control_associations(self, control_id: int) -> ControlAssociation agent_names=sorted(agent_names), ) - async def remove_all_control_associations(self, control_id: int) -> ControlAssociations: + async def remove_all_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Remove all policy and direct agent associations for a control.""" - associations = await self.list_control_associations(control_id) + associations = await self.list_control_associations( + control_id, namespace_key=namespace_key + ) if associations.policy_ids: await self._db.execute( - delete(policy_controls).where(policy_controls.c.control_id == control_id) + delete(policy_controls).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) ) if associations.agent_names: await self._db.execute( - delete(agent_controls).where(agent_controls.c.control_id == control_id) + delete(agent_controls).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) ) return associations diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 96c4aad8..2d39bfa3 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, 
HttpUpstreamAuthProvider, + NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( DEFAULT_OPERATION_ACCESS, @@ -64,6 +65,35 @@ def test_default_operation_access_covers_every_operation(): assert not missing, f"Operations missing default access mapping: {missing}" +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_provider_allows_any_operation(): + provider = NoAuthProvider(default_namespace_key="ns-local") + + principal = await provider.authorize( + _build_request(), + Operation.CONTROLS_DELETE, + ) + + assert principal == Principal(namespace_key="ns-local") + + +@pytest.mark.asyncio +async def test_no_auth_provider_grants_runtime_exchange_scope(): + provider = NoAuthProvider() + + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + ) + + assert principal.scopes == (Operation.RUNTIME_USE.value,) + + # --------------------------------------------------------------------------- # HeaderAuthProvider # --------------------------------------------------------------------------- @@ -101,7 +131,7 @@ async def test_header_provider_public_returns_default_namespace(): @pytest.mark.asyncio -async def test_header_provider_authenticated_calls_legacy_validator(): +async def test_header_provider_authenticated_calls_local_validator(): provider = HeaderAuthProvider() expected_client = MagicMock(is_admin=False, key_id="abc12345") @@ -945,6 +975,70 @@ def test_runtime_ttl_loader_accepts_max(monkeypatch): ) +def test_build_default_provider_accepts_none_mode(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def 
test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + assert auth_config._resolve_runtime_mode() == "api_key" + + +def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + assert auth_config._resolve_runtime_mode() == "jwt" + + +def test_configure_runtime_none_installs_no_auth_provider(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), HeaderAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_jwt_requires_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + with pytest.raises(RuntimeError, 
match="requires AGENT_CONTROL_RUNTIME_TOKEN_SECRET"): + auth_config.configure_auth_from_env() + + def test_configure_then_reconfigure_clears_runtime_override(monkeypatch): """Reconfiguring without a runtime secret must drop the override.""" from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index b4922b9d..dfbb15f5 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,19 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode +from agent_control_server.auth_framework import Principal from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -1106,7 +1106,12 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: request = SimpleNamespace(data=DummyData(payload)) # When: updating the control data with a non-Pydantic selector - response = await controls_module.set_control_data(control.id, request, async_db) + response = await controls_module.set_control_data( + control.id, + request, + async_db, + principal=Principal(namespace_key=DEFAULT_NAMESPACE_KEY), + ) # Then: the update succeeds and uses the original selector serialization assert response.success is True diff 
--git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 27832f18..11105589 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,14 +4,13 @@ import uuid -import pytest from fastapi.testclient import TestClient -from agent_control_server.config import auth_settings +from agent_control_server.auth_framework import set_authorizer +from agent_control_server.auth_framework.providers import NoAuthProvider from .utils import VALID_CONTROL_PAYLOAD - _CONTROLS_URL = "/api/v1/controls" _TEMPLATES_URL = "/api/v1/control-templates" @@ -234,18 +233,19 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_cannot_validate_control_data( +def test_non_admin_can_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_CREATE``.""" + """``/controls/validate`` requires ``CONTROLS_READ``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation requires CONTROLS_CREATE. - assert resp.status_code == 403, resp.text + # Then: validation is allowed for authenticated non-admin callers + assert resp.status_code == 200, resp.text + assert resp.json()["success"] is True def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: @@ -318,21 +318,16 @@ def test_unauthenticated_cannot_render_template( # --------------------------------------------------------------------------- -# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# No-auth deployment mode: explicit provider bypasses every gate. 
# --------------------------------------------------------------------------- def test_no_auth_mode_allows_writes_without_credentials( unauthenticated_client: TestClient, - monkeypatch: pytest.MonkeyPatch, ) -> None: - """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` - short-circuits to a non-admin ``Principal`` for every operation, - including admin-level writes. This pins the "no auth" deployment - path so a future refactor can't silently start enforcing. - """ - # Given: api_key_enabled is False (single-tenant OSS dev mode) - monkeypatch.setattr(auth_settings, "api_key_enabled", False) + """Explicit no-auth provider allows every operation without credentials.""" + # Given: the request-auth framework is in no-auth mode + set_authorizer(NoAuthProvider()) # When: an unauthenticated client creates a control resp = unauthenticated_client.put( diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py new file mode 100644 index 00000000..40ecd216 --- /dev/null +++ b/server/tests/test_principal_namespace_flow.py @@ -0,0 +1,141 @@ +"""HTTP-level coverage for principal-derived namespace scoping.""" + +from __future__ import annotations + +import uuid +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from agent_control_server.auth_framework import ( + Operation, + Principal, + set_authorizer, +) + +from .utils import VALID_CONTROL_PAYLOAD + + +class HeaderNamespaceAuthorizer: + """Test authorizer that maps a request header to ``Principal.namespace_key``.""" + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del context + scopes = ( + (Operation.RUNTIME_USE.value,) + if operation is Operation.RUNTIME_TOKEN_EXCHANGE + else () + ) + return Principal( + namespace_key=request.headers.get("X-Test-Namespace", "default"), + is_admin=True, + scopes=scopes, + ) + + +def 
_client(app: FastAPI, namespace_key: str) -> TestClient: + return TestClient( + app, + raise_server_exceptions=True, + headers={"X-Test-Namespace": namespace_key}, + ) + + +def _agent_payload(agent_name: str) -> dict[str, Any]: + return { + "agent": { + "agent_name": agent_name, + "agent_description": "test agent", + "agent_version": "1.0", + }, + "steps": [], + } + + +def _evaluation_payload(agent_name: str) -> dict[str, Any]: + return { + "agent_name": agent_name, + "step": { + "type": "llm", + "name": "test-step", + "input": "x marks the spot", + "context": {}, + }, + "stage": "pre", + "target_type": "env", + "target_id": "prod", + } + + +def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_a = ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + register_b = ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + assert register_a.status_code == 200, register_a.text + assert register_b.status_code == 200, register_b.text + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + attach_to_policy = ns_a.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert attach_to_policy.status_code == 200, attach_to_policy.text + + binding = ns_a.put( + "/api/v1/control-bindings", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, 
binding.text + + assert ns_b.get(f"/api/v1/controls/{control_id}").status_code == 404 + assert ns_b.get(f"/api/v1/policies/{policy_id}/controls").status_code == 404 + assert ns_b.get("/api/v1/control-bindings").json()["bindings"] == [] + + eval_a = ns_a.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_a.status_code == 200, eval_a.text + assert eval_a.json()["is_safe"] is False + assert eval_a.json()["matches"][0]["control_id"] == control_id + + eval_b = ns_b.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_b.status_code == 200, eval_b.text + assert eval_b.json()["is_safe"] is True + + +def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + control_name = f"control-{uuid.uuid4().hex[:12]}" + payload = {"name": control_name, "data": VALID_CONTROL_PAYLOAD} + + assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 + assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 295a85e2..62891ba5 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -232,9 +232,9 @@ def test_target_binding_de_duplicated_against_direct_attachment( async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) -> None: """Insert an Agent row directly so the test can simulate a foreign namespace. - The endpoint's ``get_namespace_key`` returns the default namespace; this - helper sidesteps the resolver to seed an agent that the request-time - code path should not be able to reach. + The default test authorizer returns the default namespace; this helper + sidesteps the authorizer to seed an agent that the request-time code + path should not be able to reach. 
""" from agent_control_server.models import Agent From 4a87b4ae9e387ad2622d0c0cd198979e0ad60f53 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:20:00 +0530 Subject: [PATCH 02/20] chore(sdk-ts): regenerate client docs --- .../src/generated/funcs/agents-get-evaluator.ts | 3 +-- sdks/typescript/src/generated/funcs/agents-get.ts | 3 +-- .../typescript/src/generated/funcs/agents-init.ts | 1 + .../src/generated/funcs/agents-list-controls.ts | 2 +- .../src/generated/funcs/agents-list-evaluators.ts | 3 +-- .../typescript/src/generated/funcs/agents-list.ts | 2 +- .../src/generated/funcs/agents-update.ts | 1 + .../generated/funcs/control-bindings-create.ts | 6 +----- .../src/generated/funcs/control-bindings-list.ts | 3 +-- sdks/typescript/src/generated/sdk/agents.ts | 15 +++++++-------- .../src/generated/sdk/control-bindings.ts | 9 ++------- 11 files changed, 18 insertions(+), 30 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts index acb364eb..ceca1ec0 100644 --- a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts +++ b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts @@ -37,8 +37,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). 
+ * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/funcs/agents-get.ts b/sdks/typescript/src/generated/funcs/agents-get.ts index 9724edbf..142f3062 100644 --- a/sdks/typescript/src/generated/funcs/agents-get.ts +++ b/sdks/typescript/src/generated/funcs/agents-get.ts @@ -38,8 +38,7 @@ import { Result } from "../types/fp.js"; * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 9d63358d..7150b2a4 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,6 +51,7 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index 661c5509..d1e5b27d 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,7 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git 
a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts index c4d8a4b2..4217e752 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts @@ -42,8 +42,7 @@ import { Result } from "../types/fp.js"; * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination diff --git a/sdks/typescript/src/generated/funcs/agents-list.ts b/sdks/typescript/src/generated/funcs/agents-list.ts index fda7574d..f887d0b5 100644 --- a/sdks/typescript/src/generated/funcs/agents-list.ts +++ b/sdks/typescript/src/generated/funcs/agents-list.ts @@ -42,7 +42,7 @@ import { Result } from "../types/fp.js"; * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info diff --git a/sdks/typescript/src/generated/funcs/agents-update.ts b/sdks/typescript/src/generated/funcs/agents-update.ts index e82644cf..aff9d827 100644 --- a/sdks/typescript/src/generated/funcs/agents-update.ts +++ b/sdks/typescript/src/generated/funcs/agents-update.ts @@ -40,6 +40,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items diff --git 
a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 8412487e..71dee5a0 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -33,11 +33,7 @@ import { Result } from "../types/fp.js"; * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5e7e87c3..5c90c7c2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,8 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. 
*/ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index a22f4209..0a70e128 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -39,7 +39,7 @@ export class Agents extends ClientSDK { * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info @@ -80,6 +80,7 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls @@ -106,8 +107,7 @@ export class Agents extends ClientSDK { * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). 
+ * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list @@ -140,6 +140,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items @@ -185,7 +186,7 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters @@ -256,8 +257,7 @@ export class Agents extends ClientSDK { * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination @@ -287,8 +287,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). 
+ * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5101ce74..dc6f20d3 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,8 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ async list( request?: @@ -46,11 +45,7 @@ export class ControlBindings extends ClientSDK { * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. 
*/ async create( request: models.CreateControlBindingRequest, From 71fcd7b919dc40a121831d98a15a51ca2d63c5bb Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 23:07:04 +0530 Subject: [PATCH 03/20] fix(server): address runtime auth review feedback --- .../funcs/auth-runtime-token-exchange.ts | 9 ++--- .../funcs/control-bindings-create.ts | 4 +- .../funcs/control-bindings-delete.ts | 2 +- .../generated/funcs/control-bindings-get.ts | 5 +-- .../generated/funcs/control-bindings-list.ts | 2 +- .../funcs/control-bindings-update.ts | 2 +- sdks/typescript/src/generated/sdk/auth.ts | 9 ++--- .../src/generated/sdk/control-bindings.ts | 15 ++++--- .../auth_framework/core.py | 2 - .../auth_framework/providers/header.py | 2 - .../agent_control_server/endpoints/auth.py | 19 +++++---- .../endpoints/control_bindings.py | 15 ++++--- .../endpoints/controls.py | 6 +-- server/src/agent_control_server/namespace.py | 23 ----------- .../agent_control_server/services/controls.py | 23 ++++++----- server/tests/test_auth_framework.py | 24 +++++++++++ server/tests/test_controls_auth.py | 12 +++--- .../test_runtime_token_exchange_endpoint.py | 36 ++++++++++++++++- server/tests/test_services_controls.py | 40 ++++++++++++++----- 19 files changed, 146 insertions(+), 104 deletions(-) delete mode 100644 server/src/agent_control_server/namespace.py diff --git a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts index 176693e3..7e8679c8 100644 --- a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts +++ b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts @@ -32,11 +32,10 @@ import { Result } from "../types/fp.js"; * @remarks * Mint a short-lived runtime token for the requested target. 
* - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 71dee5a0..faf99923 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -32,8 +32,8 @@ import { Result } from "../types/fp.js"; * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts index 9e4d1293..9872a9b4 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. 
Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ export function controlBindingsDelete( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-get.ts b/sdks/typescript/src/generated/funcs/control-bindings-get.ts index dafb7c7c..88b4e419 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-get.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-get.ts @@ -34,12 +34,11 @@ import { Result } from "../types/fp.js"; * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ export function controlBindingsGet( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5c90c7c2..a87ca89f 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,7 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. 
+ * storage namespace is resolved from the authenticated request. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update.ts b/sdks/typescript/src/generated/funcs/control-bindings-update.ts index b3faf800..b94520a2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-update.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-update.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ export function controlBindingsUpdate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/auth.ts b/sdks/typescript/src/generated/sdk/auth.ts index cf6de9ba..2d0cf74e 100644 --- a/sdks/typescript/src/generated/sdk/auth.ts +++ b/sdks/typescript/src/generated/sdk/auth.ts @@ -14,11 +14,10 @@ export class Auth extends ClientSDK { * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. 
* * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index dc6f20d3..5a5bcf2b 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,7 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ async list( request?: @@ -44,8 +44,8 @@ export class ControlBindings extends ClientSDK { * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ async create( request: models.CreateControlBindingRequest, @@ -104,7 +104,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ async delete( request: @@ -125,12 +125,11 @@ export class ControlBindings extends ClientSDK { * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. 
+ * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ async get( request: @@ -153,7 +152,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ async update( request: diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index e0ea6da7..058169de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -52,11 +52,9 @@ class Operation(StrEnum): POLICIES_READ = "policies.read" POLICIES_CREATE = "policies.create" POLICIES_UPDATE = "policies.update" - POLICIES_DELETE = "policies.delete" AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" - AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 228ec443..16760768 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -45,11 +45,9 @@ class AccessLevel(Enum): Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, Operation.POLICIES_CREATE: AccessLevel.ADMIN, Operation.POLICIES_UPDATE: AccessLevel.ADMIN, - Operation.POLICIES_DELETE: 
AccessLevel.ADMIN, Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, - Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index f80cd2fa..b1ade969 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,13 +2,13 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer authenticates the credential and -authorizes the implied -:data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint +target_id)``; the configured authorization provider authenticates the +credential and authorizes the implied +``runtime.token_exchange`` operation. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the returned token, which is verified locally by -:class:`LocalJwtVerifyProvider`. +the runtime JWT provider. """ from __future__ import annotations @@ -56,7 +56,7 @@ class RuntimeTokenExchangeResponse(BaseModel): async def _exchange_context(request: Request) -> dict[str, Any]: - """Surface target identifiers to the authorizer's context. + """Surface target identifiers to the authorization context. Reads the request body once. FastAPI caches the parsed body, so the endpoint's own Pydantic body model still binds normally. @@ -89,11 +89,10 @@ async def runtime_token_exchange( ) -> RuntimeTokenExchangeResponse: """Mint a short-lived runtime token for the requested target. 
- The caller's credential is authenticated and authorized by the - installed default authorizer; the resulting :class:`Principal` - supplies the actor identity and (when the upstream surfaces it) - the grant scopes and expiry. This endpoint then mints a local HS256 - token whose lifetime cannot outlive the upstream grant. + The caller's credential is authenticated and authorized before the + resolved principal supplies the actor identity, grant scopes, and + expiry. This endpoint then mints a local HS256 token whose lifetime + cannot outlive the grant. Runtime auth must be enabled via ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index d2fe4b44..87386723 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -102,8 +102,8 @@ async def create_control_binding( ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. - Each binding row is scoped to the request namespace as resolved by - the active authorizer. + Each binding row is scoped to the namespace associated with the + authenticated request. """ service = ControlBindingsService(db) binding = await service.create_binding( @@ -153,7 +153,7 @@ async def list_control_bindings( cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by the active authorizer. + storage namespace is resolved from the authenticated request. """ parsed_cursor: int | None if cursor is None: @@ -201,12 +201,11 @@ async def get_control_binding( """Read a single control binding by surrogate ID. 
Authorization is namespace-wide: the binding's target identifiers - are not forwarded to the upstream because they are only discoverable - after the row is loaded, and ``require_operation`` is single-pass. + are not available until after the row is loaded. Callers whose authorization model requires per-target permissions should use the natural-key endpoints (``PUT /by-key``, ``POST /by-key:delete``) and the target-filtered list endpoint, all - of which forward ``(target_type, target_id)`` to the authorizer. + of which include ``(target_type, target_id)`` in the request context. """ service = ControlBindingsService(db) binding = await service.get_binding_or_404( @@ -232,7 +231,7 @@ async def patch_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``PUT /by-key`` for target-scoped - upserts that forward the target to the authorizer. + upserts that include the target in the request context. """ service = ControlBindingsService(db) binding = await service.set_enabled( @@ -260,7 +259,7 @@ async def delete_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``POST /by-key:delete`` for - target-scoped detach that forwards the target to the authorizer. + target-scoped detach that includes the target in the request context. 
""" service = ControlBindingsService(db) await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 5b01593c..00d2b710 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,12 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_READ: validate exercises the materialization -# path but does not mutate stored control data. +# Authorized as CONTROLS_CREATE: validate exercises the same materialization +# path as create/update authoring flows, even though it does not save. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. diff --git a/server/src/agent_control_server/namespace.py b/server/src/agent_control_server/namespace.py deleted file mode 100644 index 30e30be5..00000000 --- a/server/src/agent_control_server/namespace.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Namespace resolution for request-scoped scoping. - -V1 always resolves to the default namespace. The function exists as a -single seam so a future change can switch every namespace-scoped -endpoint to a real per-request resolver without touching each call -site. Overriding the dependency in V1 is not supported: only this -binding/evaluation layer reads it; controls, agents, and policies still -write under the default namespace, so an override here would create -inconsistent rows. Future work will thread a single resolver through -every write path together. 
-""" - -from __future__ import annotations - -from .models import DEFAULT_NAMESPACE_KEY - - -def get_namespace_key() -> str: - """Return the namespace_key for the current request. - - V1 returns ``DEFAULT_NAMESPACE_KEY`` unconditionally. - """ - return DEFAULT_NAMESPACE_KEY diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 41a62282..e3a5fd26 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,7 +20,6 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( - DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -100,7 +99,7 @@ def __init__(self, db: AsyncSession) -> None: def create_control( self, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, name: str, data: dict[str, Any], ) -> Control: @@ -161,17 +160,19 @@ async def get_active_control_or_404( control_id: int, *, for_update: bool = False, - namespace_key: str | None = None, + namespace_key: str, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND. - When ``namespace_key`` is supplied, the lookup is scoped to that - namespace; a control that exists only in another namespace - surfaces as 404 (non-disclosing). + The lookup is scoped to the supplied namespace; a control that + exists only in another namespace surfaces as 404 + (non-disclosing). 
""" - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -190,7 +191,7 @@ async def active_control_name_exists( self, name: str, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" @@ -537,7 +538,7 @@ async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 2d39bfa3..799b2d52 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + LocalJwtVerifyProvider, NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( @@ -1029,6 +1030,29 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +@pytest.mark.asyncio +async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + try: + 
auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.CONTROLS_READ), HttpUpstreamAuthProvider) + assert isinstance(get_authorizer(Operation.RUNTIME_USE), LocalJwtVerifyProvider) + runtime_config = auth_config.runtime_auth_config() + assert runtime_config is not None + assert runtime_config.secret == _TEST_SECRET + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 11105589..7975dad9 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,10 +4,9 @@ import uuid -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import set_authorizer from agent_control_server.auth_framework.providers import NoAuthProvider +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -233,19 +232,18 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_can_validate_control_data( +def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_READ``.""" + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is allowed for authenticated non-admin callers - assert resp.status_code == 200, resp.text - assert resp.json()["success"] is True + # Then: validation is admin-only + assert resp.status_code == 403, resp.text def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 8d333a5c..1b1edae2 
100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -11,8 +11,6 @@ from datetime import UTC, datetime, timedelta import pytest -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( RuntimeAuthConfig, @@ -25,6 +23,7 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) +from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -180,6 +179,39 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ assert principal.caller_id == "actor-rt" +def test_evaluation_rejects_runtime_jwt_for_wrong_target( + client: TestClient, + runtime_config_enabled, +): + """A runtime JWT minted for one target cannot be used for another target.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + "target_type": "log_stream", + "target_id": "ls-other", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_id does not match the request." 
+ + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b858c527..3815f26b 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -8,10 +8,6 @@ import pytest from agent_control_models.errors import ErrorCode -from sqlalchemy import insert, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session - from agent_control_server.errors import APIValidationError from agent_control_server.models import ( DEFAULT_NAMESPACE_KEY, @@ -27,6 +23,9 @@ from agent_control_server.services.controls import ( ControlService, ) +from sqlalchemy import insert, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from .conftest import AsyncSessionTest, engine from .utils import VALID_CONTROL_PAYLOAD @@ -70,7 +69,11 @@ async def _create_versioned_control( async with AsyncSessionTest() as session: service = ControlService(session) - control = service.create_control(name=control_name, data=control_data) + control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=control_name, + data=control_data, + ) await service.create_version( control, event_type="created", @@ -143,6 +146,7 @@ async def test_create_control_transaction_rollback_does_not_persist_control_or_v async with AsyncSessionTest() as session: service = ControlService(session) control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=control_name, data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -167,7 +171,10 @@ async def test_replace_control_data_transaction_rollback_preserves_prior_state() async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + 
namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = "Should not persist" service.replace_control_data(control, data=updated_data) @@ -194,7 +201,10 @@ async def test_patch_mutation_transaction_rollback_preserves_prior_state() -> No async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.rename_control(control, name=f"{control_name}-renamed") service.set_control_enabled(control, enabled=False) await service.create_version( @@ -221,7 +231,10 @@ async def test_delete_control_transaction_rollback_preserves_active_state() -> N async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) await service.create_version( control, @@ -511,7 +524,10 @@ async def test_list_active_control_counts_by_agent_deduplicates_and_filters_inac await async_db.commit() # When: counting active controls for the agent - counts = await ControlService(async_db).list_active_control_counts_by_agent([agent.name]) + counts = await ControlService(async_db).list_active_control_counts_by_agent( + [agent.name], + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: active controls are deduplicated and inactive controls are excluded assert counts == {agent.name: 2} @@ -572,6 +588,7 @@ async def test_create_version_allocates_sequential_numbers_under_concurrent_muta async with AsyncSessionTest() as setup_session: setup_service = ControlService(setup_session) control = setup_service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=f"control-{uuid.uuid4()}", 
data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -592,7 +609,10 @@ async def mutate_and_version(description: str) -> None: async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = description service.replace_control_data(control, data=updated_data) From c0feed9fc42e42a800464c4aa21d89f39b0cbcb7 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 16:43:39 +0530 Subject: [PATCH 04/20] feat(server): operator-configurable extra forwarded headers on HttpUpstream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default forward set (X-API-Key, Authorization, Cookie) only covers credential headers Agent Control itself reads. Deployments whose upstream authenticates against a different header name (e.g., a deployer-specific API-key header) had no way to surface that credential through HttpUpstreamAuthProvider — the inbound header reached AC but never crossed the upstream call. Add an extra_forward_headers config field on HttpUpstreamConfig (defaulting to the empty tuple) that operators populate via the new AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS env var (comma- separated). The provider's _forward_headers iterates over the union of the default set and the extras, deduplicating case-insensitively so a duplicate name (cross-set or within extras) does not produce two copies on the wire. 
Tests: - forwards a configured extra header alongside defaults - default forward set unchanged when extras are empty - extras dedupe against defaults case-insensitively - _parse_extra_forward_headers parametric: None / empty / single / multiple / whitespace / empty-entries / case-folded duplicates - configure_auth_from_env threads the parsed tuple onto the provider Lint clean, typecheck clean, full server suite (747) green. --- .../auth_framework/config.py | 29 +++++ .../auth_framework/providers/http_upstream.py | 20 ++- .../endpoints/controls.py | 3 +- server/tests/test_auth_framework.py | 115 ++++++++++++++++++ 4 files changed, 163 insertions(+), 4 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index c8f428dc..8c39a2ec 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -46,6 +46,7 @@ _UPSTREAM_TIMEOUT_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_TIMEOUT_SECONDS" _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" +_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" # Runtime flow. 
_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" @@ -196,6 +197,9 @@ def _build_default_provider() -> RequestAuthorizer: timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0")) token = os.environ.get(_UPSTREAM_TOKEN_ENV) token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token") + extra_forward_headers = _parse_extra_forward_headers( + os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) + ) _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -203,6 +207,7 @@ def _build_default_provider() -> RequestAuthorizer: timeout_seconds=timeout, service_token=token, service_token_header=token_header, + extra_forward_headers=extra_forward_headers, ) ) raise RuntimeError( @@ -210,6 +215,30 @@ def _build_default_provider() -> RequestAuthorizer: ) +def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: + """Parse a comma-separated header list into a deduplicated tuple. + + Empty / unset env var returns an empty tuple. Whitespace around each + name is stripped. Empty entries (e.g. ``"X-A,,X-B"``) are dropped. + Order is preserved; duplicates (case-insensitive) are dropped after + the first occurrence. 
+ """ + if not raw or not raw.strip(): + return () + seen: set[str] = set() + result: list[str] = [] + for raw_name in raw.split(","): + name = raw_name.strip() + if not name: + continue + lower = name.lower() + if lower in seen: + continue + seen.add(lower) + result.append(name) + return tuple(result) + + def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 8d5c850c..78ed9ae2 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -60,7 +60,7 @@ _logger = get_logger(__name__) -_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") +_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") class _UpstreamGrant(BaseModel): @@ -136,6 +136,17 @@ class HttpUpstreamConfig: service_token_header: str = "X-Agent-Control-Service-Token" + extra_forward_headers: tuple[str, ...] = () + """Additional inbound request headers to forward to the upstream + on top of the default ``(X-API-Key, Authorization, Cookie)`` set. + + Use this when the upstream authenticates via a header the provider + does not forward by default (e.g., a deployer-specific API-key + header). Header lookups against the inbound request are + case-insensitive; an absent inbound header is silently + dropped.
Names duplicating the default set or each other (after + case-folding) are deduplicated.""" + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -190,7 +201,12 @@ async def authorize( def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} - for name in _FORWARDED_HEADERS: + seen: set[str] = set() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self._config.extra_forward_headers): + lower = name.lower() + if lower in seen: + continue + seen.add(lower) value = request.headers.get(name) if value is not None: headers[name] = value diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 00d2b710..b4fa8d0b 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,8 +787,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_CREATE: validate exercises the same materialization -# path as create/update authoring flows, even though it does not save. +# Validation uses the authoring path, so require create access. 
async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 799b2d52..dc3a1787 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -261,6 +261,75 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +@pytest.mark.asyncio +async def test_http_upstream_forwards_extra_headers(): + # Given: a provider configured with an extra header in its forward list + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("X-Deployer-Auth",)}, + ) + + # When: the inbound request carries the extra header + inbound = _build_request(headers={"X-Deployer-Auth": "k_abc", "X-API-Key": "k1"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: both the default and the extra header reach the upstream + assert captured["headers"]["x-deployer-auth"] == "k_abc" + assert captured["headers"]["x-api-key"] == "k1" + + +@pytest.mark.asyncio +async def test_http_upstream_default_forward_set_unchanged(): + # Given: a provider with no extra_forward_headers + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream(factory) + + # When: the inbound carries an unlisted header alongside a default one + inbound = _build_request( + headers={"X-API-Key": "k1", "X-Deployer-Auth": "should-not-forward"} + ) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: only the default-set header reaches the upstream + assert 
captured["headers"].get("x-api-key") == "k1" + assert "x-deployer-auth" not in captured["headers"] + + +@pytest.mark.asyncio +async def test_http_upstream_extra_forward_dedupes_against_defaults(): + # Given: extra list duplicates a default header (different case) + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("x-api-key", "Authorization")}, + ) + + # When: inbound has both + inbound = _build_request(headers={"X-API-Key": "k1", "Authorization": "Bearer t"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: each header appears exactly once on the upstream request + forwarded = captured["headers"] + assert sum(1 for k in forwarded if k.lower() == "x-api-key") == 1 + assert sum(1 for k in forwarded if k.lower() == "authorization") == 1 + + @pytest.mark.asyncio @pytest.mark.parametrize( "status, expected", @@ -1053,6 +1122,52 @@ async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): await auth_config.teardown_auth() +@pytest.mark.parametrize( + "raw, expected", + [ + (None, ()), + ("", ()), + (" ", ()), + ("X-One", ("X-One",)), + ("X-One,X-Two", ("X-One", "X-Two")), + (" X-One , X-Two ", ("X-One", "X-Two")), + ("X-One,,X-Two", ("X-One", "X-Two")), + ("X-One,x-one,X-One", ("X-One",)), + ("X-A,X-B,x-a,X-C,X-b", ("X-A", "X-B", "X-C")), + ], +) +def test_parse_extra_forward_headers(raw, expected): + from agent_control_server.auth_framework.config import _parse_extra_forward_headers + + assert _parse_extra_forward_headers(raw) == expected + + +@pytest.mark.asyncio +async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): + """Setting the env var threads extra_forward_headers into the provider.""" + from agent_control_server.auth_framework import config as 
auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS", + "X-Deployer-Auth, X-Deployer-Trace", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.extra_forward_headers == ( + "X-Deployer-Auth", + "X-Deployer-Trace", + ) + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 5399d5d8f7a8991cd6f0f2bac9bd33a2fe77c81b Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:36:45 +0530 Subject: [PATCH 05/20] fix(server): preserve default runtime auth fallback --- .../auth_framework/config.py | 37 +++++++++------- server/tests/test_auth_framework.py | 44 +++++++++++++++++-- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 8c39a2ec..595c3117 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -16,8 +16,8 @@ :class:`NoAuthProvider`, ``api_key`` uses :class:`HeaderAuthProvider`, and ``jwt`` uses :class:`LocalJwtVerifyProvider`. When the mode is unset, startup - preserves historical behavior by selecting ``jwt`` if - ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + selects ``jwt`` if ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set; + otherwise runtime falls through to the default authorizer. 
The ``runtime.token_exchange`` operation continues to flow through the default authorizer because the exchange itself is shaped like a management call (forward credential, get grant). @@ -96,10 +96,11 @@ def configure_auth_from_env() -> None: Runtime flow: - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. - - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime - token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token secret is configured): :class:`LocalJwtVerifyProvider`. + - unset mode without a runtime token secret: fall through to the default + authorizer. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -121,20 +122,26 @@ def configure_auth_from_env() -> None: set_authorizer(default) _active_providers.append(default) - runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) - if runtime_mode == "jwt": + if runtime_mode == "default": _logger.info( - "Runtime auth provider: jwt override installed for %s", + "Runtime auth provider: default authorizer handles %s", Operation.RUNTIME_USE.value, ) else: - _logger.info( - "Runtime auth provider: %s override installed for %s", - runtime_mode, - Operation.RUNTIME_USE.value, - ) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": + _logger.info( + "Runtime auth provider: jwt override installed for %s", + Operation.RUNTIME_USE.value, + ) + else: + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, + 
Operation.RUNTIME_USE.value, + ) async def teardown_auth() -> None: @@ -242,7 +249,7 @@ def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): - return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "default" mode = raw.strip().lower() if mode in {"none", "no_auth"}: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index dc3a1787..20c58aed 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -700,7 +699,6 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1053,13 +1051,13 @@ def test_build_default_provider_accepts_none_mode(monkeypatch): assert isinstance(auth_config._build_default_provider(), NoAuthProvider) -def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): +def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) - assert auth_config._resolve_runtime_mode() == "api_key" + assert auth_config._resolve_runtime_mode() == "default" def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): @@ -1099,6 +1097,44 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert 
auth_config.runtime_auth_config() is None +def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +@pytest.mark.asyncio +async def test_configure_runtime_unset_preserves_http_upstream_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + try: + auth_config.configure_auth_from_env() + + default_provider = get_authorizer(Operation.CONTROLS_READ) + runtime_provider = get_authorizer(Operation.RUNTIME_USE) + assert isinstance(default_provider, HttpUpstreamAuthProvider) + assert runtime_provider is default_provider + assert auth_config.runtime_auth_config() is None + finally: + await auth_config.teardown_auth() + + @pytest.mark.asyncio async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 8c567da18f86ac76cfca40e742d00e9e4e26201b Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 14:39:01 +0530 Subject: [PATCH 06/20] fix(server): harden auth scoping --- docs/README.md | 1 + docs/auth.md | 148 ++++++++++++++++++ models/src/agent_control_models/server.py | 3 +- .../agent_control_server/endpoints/agents.py | 84 
+++++++++- .../agent_control_server/endpoints/auth.py | 17 +- .../endpoints/controls.py | 44 +++++- server/tests/test_principal_namespace_flow.py | 33 +++- server/tests/test_target_merged_contract.py | 96 +++++++++++- 8 files changed, 402 insertions(+), 24 deletions(-) create mode 100644 docs/auth.md diff --git a/docs/README.md b/docs/README.md index 9b7cb757..e53dcf13 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,6 +10,7 @@ This repository keeps documentation concise. The full documentation lives on the - [Controls](https://docs.agentcontrol.dev/concepts/controls) — Define and configure control rules - [Reference](https://docs.agentcontrol.dev/core/reference) — SDK and server API reference - [Configuration](https://docs.agentcontrol.dev/core/configuration) — Environment variables, auth, and database settings +- [Server auth contract](auth.md) — Pluggable auth modes, HTTP upstream contract, and runtime JWT claims - [UI Quickstart](https://docs.agentcontrol.dev/core/ui-quickstart) — Run the dashboard and manage controls visually ## Examples diff --git a/docs/auth.md b/docs/auth.md new file mode 100644 index 00000000..5002faa8 --- /dev/null +++ b/docs/auth.md @@ -0,0 +1,148 @@ +# Server Auth Contract + +Agent Control keeps authentication and authorization provider-neutral. The server asks a configured provider whether a request may perform an operation, then scopes all data access with the returned `Principal`. + +## Operations + +Operations are stable strings. Deployers map them to their own permission model. + +```text +controls.read +controls.create +controls.update +controls.delete +policies.read +policies.create +policies.update +agents.read +agents.create +agents.update +control_bindings.read +control_bindings.write +runtime.token_exchange +runtime.use +``` + +## Principal + +Providers return a generic principal. Agent Control treats `namespace_key`, `caller_id`, `target_type`, and `target_id` as opaque strings.
+ +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +`namespace_key` is the tenancy boundary. Server queries filter by it, and namespace-aware foreign keys prevent cross-namespace references. + +## Auth Modes + +Management auth is selected by `AGENT_CONTROL_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| `none` | No credentials required. Intended for local development only. | +| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS`. This is the default. `header` is accepted as a backwards-compatible alias. | +| `http_upstream` | POST each management authorization decision to `AGENT_CONTROL_AUTH_UPSTREAM_URL`. | + +Runtime auth is selected by `AGENT_CONTROL_RUNTIME_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| unset | Use `jwt` when `AGENT_CONTROL_RUNTIME_TOKEN_SECRET` is set. Otherwise runtime requests fall through to management auth. | +| `none` | No runtime credentials required. Intended for local development only. | +| `api_key` | Validate runtime requests with the same local API-key mechanism. | +| `jwt` | Require target-bound runtime tokens minted by `/api/v1/auth/runtime-token-exchange`. | + +Common combinations: + +| Management | Runtime | Use case | +| --- | --- | --- | +| `api_key` | unset | Existing standalone deployments. | +| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. | +| `http_upstream` | `jwt` | External identity or authorization service for management, local token verify for high-volume runtime calls. | +| `none` | `none` | Single-process local development. Do not use in production. 
| + +## HTTP Upstream Contract + +When `AGENT_CONTROL_AUTH_MODE=http_upstream`, the server sends: + +```http +POST {AGENT_CONTROL_AUTH_UPSTREAM_URL} +``` + +```json +{ + "operation": "control_bindings.write", + "context": { + "target_type": "session", + "target_id": "target-123" + } +} +``` + +The provider forwards inbound `X-API-Key`, `Authorization`, and `Cookie` headers. Add deployer-specific header names with `AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS`, for example: + +```text +AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS=Vendor-API-Key,X-Workspace-Id +``` + +If `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN` is set, it is forwarded on `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER` or `X-Agent-Control-Service-Token` by default. + +A successful upstream response is: + +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +Only `namespace_key` is always required. `target_type` and `target_id` must be returned together when present. `expires_at` must include timezone information. + +Status handling: + +| Upstream status | Agent Control result | +| --- | --- | +| `200` | Parse the principal grant. | +| `401` | Authentication error. | +| `403` | Forbidden error. | +| `404` | Not found error. | +| `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | +| Other statuses or malformed JSON | Fail closed with `503` or `502`. | + +## Runtime JWT Claims + +`/api/v1/auth/runtime-token-exchange` is a management-style request. The configured management provider authorizes `runtime.token_exchange` for the requested target. Agent Control then mints its own HS256 JWT with `AGENT_CONTROL_RUNTIME_TOKEN_SECRET`. 
+ +The token payload contains: + +```json +{ + "iss": "agent-control/server", + "domain": "runtime", + "namespace_key": "tenant-a", + "actor_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "iat": 1778509800, + "exp": 1778510100, + "jti": "opaque-token-id" +} +``` + +Verification requires the expected issuer, `domain="runtime"`, a valid signature, an unexpired `exp`, and `runtime.use` in `scopes`. The token is accepted only for requests whose `target_type` and `target_id` match the bound target. + +The expiry is the earlier of `AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS` and the upstream grant's `expires_at` when supplied. Runtime token TTLs are capped at 86400 seconds. diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index 9b890b91..3529a5d4 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -640,7 +640,7 @@ class CreateControlBindingRequest(BaseModel): target_type: ControlBindingTargetField = Field( ..., - description="Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream').", + description="Opaque attachment kind (caller-defined; e.g. 'environment', 'session').", ) target_id: ControlBindingTargetField = Field( ..., description="Opaque external identifier within the target_type." 
@@ -760,4 +760,3 @@ class DeleteControlBindingByKeyResponse(BaseModel): ), ) - diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index ac099911..57ca1ebc 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -29,20 +29,21 @@ SetPolicyResponse, StepKey, ) -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Request from jsonschema_rs import ValidationError as JSONSchemaValidationError from pydantic import BaseModel, ValidationError from sqlalchemy import delete, func, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth_framework import Operation, Principal, require_operation +from ..auth_framework import Operation, Principal, get_authorizer, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, BadRequestError, ConflictError, DatabaseError, + ForbiddenError, NotFoundError, ) from ..logging_utils import get_logger @@ -85,6 +86,81 @@ type StepKeyTuple = tuple[str, str] +def _complete_target_context( + target_type: object | None, + target_id: object | None, +) -> dict[str, str] | None: + """Return target context only when both halves are present strings.""" + if not isinstance(target_type, str) or not isinstance(target_id, str): + return None + if not target_type or not target_id: + return None + return {"target_type": target_type, "target_id": target_id} + + +async def _init_agent_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from an ``initAgent`` body.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return None + if not isinstance(body, dict): + return None + return _complete_target_context(body.get("target_type"), body.get("target_id")) + + +def 
_agent_controls_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from ``GET /agents/{name}/controls``.""" + return _complete_target_context( + request.query_params.get("target_type"), + request.query_params.get("target_id"), + ) + + +async def _authorize_target_read_if_present( + request: Request, + context: dict[str, str] | None, +) -> Principal | None: + """Require target read authorization before returning target-merged controls.""" + if context is None: + return None + return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( + request, + Operation.CONTROL_BINDINGS_READ, + context, + ) + + +async def _init_agent_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + await _init_agent_target_context(request), + ) + + +async def _agent_controls_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + _agent_controls_target_context(request), + ) + + +def _ensure_target_principal_matches_namespace( + principal: Principal, + target_principal: Principal | None, +) -> None: + """Fail closed if the target authorization resolves to a different namespace.""" + if target_principal is None: + return + if target_principal.namespace_key == principal.namespace_key: + return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Target authorization resolved to a different namespace.", + hint="Ensure the credential is scoped to the requested target and namespace.", + ) + + # ============================================================================= # List Agents Models # ============================================================================= @@ -445,6 +521,7 @@ async def init_agent( request: InitAgentRequest, db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), + target_principal: 
Principal | None = Depends(_init_agent_target_principal), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -474,6 +551,7 @@ async def init_agent( InitAgentResponse with created flag and the effective controls """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() @@ -1493,6 +1571,7 @@ async def list_agent_controls( ), db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), + target_principal: Principal | None = Depends(_agent_controls_target_principal), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1527,6 +1606,7 @@ async def list_agent_controls( HTTPException 404: Agent not found """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) if (target_type is None) != (target_id is None): raise BadRequestError( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index b1ade969..7125b64d 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -28,8 +28,10 @@ mint_runtime_token, ) from ..errors import APIError, BadRequestError +from ..logging_utils import get_logger router = APIRouter(prefix="/auth", tags=["auth"]) +_logger = get_logger(__name__) class RuntimeTokenExchangeRequest(BaseModel): @@ -38,7 +40,7 @@ class RuntimeTokenExchangeRequest(BaseModel): model_config = ConfigDict(extra="forbid") target_type: str = Field( - ..., description="Opaque target kind (e.g., ``log_stream``).", min_length=1 + ..., description="Opaque target kind (e.g., ``session``).", min_length=1 ) target_id: str = Field(..., description="Opaque target identifier.", min_length=1) @@ -175,6 
+177,19 @@ async def runtime_token_exchange( hint="Check the runtime token configuration.", ) from exc + _logger.info( + "Runtime token exchanged", + extra={ + "namespace_key": claims.namespace_key, + "actor_id": claims.actor_id, + "target_type": claims.target_type, + "target_id": claims.target_id, + "scopes": list(claims.scopes), + "expires_at": claims.expires_at.isoformat(), + "jti": claims.jti, + }, + ) + return RuntimeTokenExchangeResponse( token=token, expires_at=claims.expires_at, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index b4fa8d0b..6e6441e9 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -195,12 +195,17 @@ async def _render_and_validate_template_input( template_input: TemplateControlInput, *, db: AsyncSession, + namespace_key: str, enabled: bool = True, ) -> ControlDefinition: """Render a template-backed input and validate evaluator config.""" rendered = render_template_control_input(template_input, enabled=enabled) try: - await _validate_control_definition(rendered.control, db) + await _validate_control_definition( + rendered.control, + db, + namespace_key=namespace_key, + ) except APIValidationError as exc: raise remap_template_api_error( exc, @@ -214,6 +219,7 @@ async def _materialize_control_input( control_input: ControlDefinition | TemplateControlInput, *, db: AsyncSession, + namespace_key: str, current_payload: object | None = None, control_id: int | None = None, ) -> ControlDefinition | UnrenderedTemplateControl: @@ -226,6 +232,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -244,6 +251,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -262,12 +270,19 
@@ async def _materialize_control_input( raise RuntimeError("control_id is required for template-backed raw updates") raise _template_backed_raw_update_conflict(control_id) - await _validate_control_definition(control_input, db) + await _validate_control_definition( + control_input, + db, + namespace_key=namespace_key, + ) return control_input async def _validate_control_definition( - control_def: ControlDefinition, db: AsyncSession + control_def: ControlDefinition, + db: AsyncSession, + *, + namespace_key: str, ) -> None: """Validate evaluator config for definitions referencing known global evaluators. @@ -296,7 +311,10 @@ async def _validate_control_definition( agent_data = agent_data_by_name.get(agent_namespace) if agent_data is None: agent_result = await db.execute( - select(Agent).where(Agent.name == agent_namespace) + select(Agent).where( + Agent.name == agent_namespace, + Agent.namespace_key == namespace_key, + ) ) agent = agent_result.scalars().first() if agent is None: @@ -447,7 +465,7 @@ async def _validate_control_definition( async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: """Render a template-backed control without persisting it.""" control_def = await _render_and_validate_template_input( @@ -456,6 +474,7 @@ async def render_control_template( template_values=request.template_values, ), db=db, + namespace_key=principal.namespace_key, enabled=True, ) return RenderControlTemplateResponse(control=control_def) @@ -504,7 +523,11 @@ async def create_control( hint="Choose a different name or update the existing control.", ) - control_def = await _materialize_control_input(request.data, db=db) + control_def = await _materialize_control_input( + request.data, + db=db, + namespace_key=namespace_key, + ) 
control_data = _serialize_control_data(control_def) control = control_service.create_control( @@ -751,6 +774,7 @@ async def set_control_data( control_def = await _materialize_control_input( request.data, db=db, + namespace_key=principal.namespace_key, current_payload=control.data, control_id=control_id, ) @@ -791,7 +815,7 @@ async def set_control_data( async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -805,7 +829,11 @@ async def validate_control_data( """ # Validate mirrors create: complete template values trigger a full render, # incomplete values validate structure only (matching unrendered create). - await _materialize_control_input(request.data, db=db) + await _materialize_control_input( + request.data, + db=db, + namespace_key=principal.namespace_key, + ) return ValidateControlDataResponse(success=True) diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 40ecd216..14d2d874 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -3,16 +3,16 @@ from __future__ import annotations import uuid +from copy import deepcopy from typing import Any -from fastapi import FastAPI, Request -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import ( Operation, Principal, set_authorizer, ) +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -139,3 +139,30 @@ def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAP assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 assert ns_b.put("/api/v1/controls", 
json=payload).status_code == 200 + + +def test_agent_scoped_evaluator_validation_uses_principal_namespace(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_b = ns_b.post( + "/api/v1/agents/initAgent", + json={ + **_agent_payload(agent_name), + "evaluators": [{"name": "custom", "config_schema": {"type": "object"}}], + }, + ) + assert register_b.status_code == 200, register_b.text + + control_data = deepcopy(VALID_CONTROL_PAYLOAD) + control_data["condition"]["evaluator"] = { + "name": f"{agent_name}:custom", + "config": {}, + } + + resp = ns_a.post("/api/v1/controls/validate", json={"data": control_data}) + assert resp.status_code == 404, resp.text + assert resp.json()["detail"] == f"Agent '{agent_name}' not found" diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 62891ba5..6bc4ab0f 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -18,11 +18,37 @@ from copy import deepcopy from typing import Any +import pytest +from agent_control_server.auth_framework import Operation, Principal, set_authorizer +from fastapi import Request from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD, canonicalize_control_payload +class RecordingAuthorizer: + """Authorizer that records operation/context pairs for endpoint contract tests.""" + + def __init__(self, *, target_namespace_key: str = "default") -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + self.target_namespace_key = target_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + namespace_key = ( + self.target_namespace_key + if operation is 
Operation.CONTROL_BINDINGS_READ and context is not None + else "default" + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + def _agent_payload( agent_name: str, *, @@ -115,7 +141,7 @@ def _list_effective_via_get( # --------------------------------------------------------------------------- -def test_initAgent_with_target_merges_direct_and_target_controls( +def test_init_agent_with_target_merges_direct_and_target_controls( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -134,7 +160,7 @@ def test_initAgent_with_target_merges_direct_and_target_controls( assert returned_ids == {direct_id, target_id_ctrl} -def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( +def test_init_agent_newly_created_with_target_picks_up_pre_existing_bindings( client: TestClient, ) -> None: """Bindings can pre-exist the agent row. @@ -154,7 +180,7 @@ def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( assert returned_ids == [pre_existing] -def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: +def test_init_agent_partial_target_pair_rejected(client: TestClient) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" payload = _agent_payload(agent_name) payload["target_type"] = "env" # target_id omitted @@ -162,12 +188,28 @@ def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: assert resp.status_code == 422 +def test_init_agent_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + body = _register_agent(client, agent_name, target_type="env", target_id="prod") + + assert body["created"] is True + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + # --------------------------------------------------------------------------- # GET 
/agents/{name}/controls contract. # --------------------------------------------------------------------------- -def test_get_agent_controls_with_target_matches_initAgent_response( +def test_get_agent_controls_with_target_matches_init_agent_response( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -200,6 +242,45 @@ def test_get_agent_controls_partial_target_pair_returns_400( assert resp.status_code == 400 +def test_get_agent_controls_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + authorizer.calls.clear() + + ids = _list_effective_via_get( + client, + agent_name, + target_type="env", + target_id="prod", + ) + + assert ids == [] + assert (Operation.AGENTS_READ, None) in authorizer.calls + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + +def test_get_agent_controls_rejects_target_namespace_mismatch( + client: TestClient, +) -> None: + set_authorizer(RecordingAuthorizer(target_namespace_key="other-ns")) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + + resp = client.get( + f"/api/v1/agents/{agent_name}/controls", + params={"target_type": "env", "target_id": "prod"}, + ) + + assert resp.status_code == 403, resp.text + + def test_get_agent_controls_no_target_omits_target_bindings( client: TestClient, ) -> None: @@ -243,11 +324,10 @@ async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) await async_db.commit() -import pytest # noqa: E402 (kept local; the rest of the file is sync) - - @pytest.mark.asyncio -async def test_get_agent_controls_cross_namespace_returns_404(client: TestClient, async_db) -> None: +async def test_get_agent_controls_cross_namespace_returns_404( + client: TestClient, async_db +) -> None: """Agent 
existing only in another namespace must not surface here. The merged-resolver contract is namespace-scoped end-to-end; if the From 4e8e035ce02f25730ed3dcf0e99e003d8a6e336c Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:20:43 +0530 Subject: [PATCH 07/20] docs(server): clarify upstream auth failure mapping --- docs/auth.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/auth.md b/docs/auth.md index 5002faa8..7aafd2ad 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -120,7 +120,8 @@ Status handling: | `403` | Forbidden error. | | `404` | Not found error. | | `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | -| Other statuses or malformed JSON | Fail closed with `503` or `502`. | +| Other statuses or upstream network errors | Fail closed with `503`. | +| Malformed `200` principal response | Fail closed with `502`. | ## Runtime JWT Claims From 3189369225601c80d09b3dd56d2a737c0dfbb234 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:50:43 +0530 Subject: [PATCH 08/20] docs(server): explain target principal authorization --- .../agent_control_server/endpoints/agents.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 57ca1ebc..1b380026 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -121,7 +121,20 @@ async def _authorize_target_read_if_present( request: Request, context: dict[str, str] | None, ) -> Principal | None: - """Require target read authorization before returning target-merged controls.""" + """Require target read authorization before returning target-merged controls. 
+ + Agent endpoints that accept optional target context have two separate + authorization decisions: + + - the endpoint operation itself (for example, ``agents.create``), whose + result is exposed to the route as ``principal``; + - the target binding read (``control_bindings.read``), whose result is + exposed as ``target_principal``. + + Keeping the results separate lets the route verify that the caller's + namespace and the target's resolved namespace agree before merging + target-bound controls into the response. + """ if context is None: return None return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( @@ -545,7 +558,8 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent create operation + target_principal: Optional principal from the target binding read check Returns: InitAgentResponse with created flag and the effective controls @@ -1596,7 +1610,8 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent read operation + target_principal: Optional principal from the target binding read check Returns: AgentControlsResponse with controls matching the requested state filters From e98207e50aa6fd0d652821b017f159810953449a Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:59:55 +0530 Subject: [PATCH 09/20] chore(sdk-ts): refresh generated client docs --- sdks/typescript/src/generated/funcs/agents-init.ts | 3 ++- sdks/typescript/src/generated/funcs/agents-list-controls.ts | 3 ++- .../src/generated/models/create-control-binding-request.ts | 2 +- .../src/generated/models/runtime-token-exchange-request.ts | 2 +- 
sdks/typescript/src/generated/sdk/agents.ts | 6 ++++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 7150b2a4..d1136c2f 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,7 +51,8 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index d1e5b27d..619a45d6 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,8 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/models/create-control-binding-request.ts b/sdks/typescript/src/generated/models/create-control-binding-request.ts index ace9f49b..f4e0c940 100644 --- a/sdks/typescript/src/generated/models/create-control-binding-request.ts +++ b/sdks/typescript/src/generated/models/create-control-binding-request.ts @@ -22,7 +22,7 @@ export type CreateControlBindingRequest 
= { */ targetId: string; /** - * Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream'). + * Opaque attachment kind (caller-defined; e.g. 'environment', 'session'). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts index 65e02bda..e20ed22e 100644 --- a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts +++ b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts @@ -14,7 +14,7 @@ export type RuntimeTokenExchangeRequest = { */ targetId: string; /** - * Opaque target kind (e.g., ``log_stream``). + * Opaque target kind (e.g., ``session``). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index 0a70e128..bed5b41f 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -80,7 +80,8 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls @@ -186,7 +187,8 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters From 068673e47db4e9f4ab7482f4629974ac18c180ba Mon Sep 17 00:00:00 2001 From: Abhinav 
Gupta Date: Tue, 12 May 2026 13:44:38 +0530 Subject: [PATCH 10/20] fix(server): sanitize jsonvalue openapi variants --- server/src/agent_control_server/main.py | 15 +++++++++++-- server/tests/test_main_lifespan.py | 28 ++++++++++++++++++------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index ddd22195..90763e21 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -334,6 +334,16 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- ) +JSON_VALUE_SCHEMA_NAMES = ( + "JSONValue", + "JSONValue-Input", + "JSONValue-Output", + "JsonValue", + "JsonValue-Input", + "JsonValue-Output", +) + + # Override OpenAPI to avoid recursive JSONValue schema issues in TS generators. def custom_openapi() -> dict[str, Any]: if app.openapi_schema: @@ -347,8 +357,9 @@ def custom_openapi() -> dict[str, Any]: ) schemas = openapi_schema.get("components", {}).get("schemas", {}) - if "JSONValue" in schemas: - schemas["JSONValue"] = {"description": "Any JSON value"} + for schema_name in JSON_VALUE_SCHEMA_NAMES: + if schema_name in schemas: + schemas[schema_name] = {"description": "Any JSON value"} # This route is intentionally public metadata. FastAPI still emits inherited # API-key security for it, so patch only this operation in the generated spec. 
diff --git a/server/tests/test_main_lifespan.py b/server/tests/test_main_lifespan.py index 5a557743..e6e6f595 100644 --- a/server/tests/test_main_lifespan.py +++ b/server/tests/test_main_lifespan.py @@ -1,5 +1,8 @@ from __future__ import annotations +from fastapi import FastAPI +from fastapi.testclient import TestClient + from agent_control_server import main as main_module from agent_control_server.config import observability_settings, settings from agent_control_server.main import lifespan @@ -8,8 +11,6 @@ register_control_event_sink_factory, unregister_control_event_sink_factory, ) -from fastapi import FastAPI -from fastapi.testclient import TestClient def test_lifespan_initializes_observability_when_enabled(monkeypatch) -> None: @@ -156,11 +157,22 @@ def test_lifespan_skips_observability_when_disabled(monkeypatch) -> None: assert not hasattr(app.state, "event_ingestor") -def test_custom_openapi_replaces_jsonvalue(monkeypatch) -> None: - # Given: a custom openapi generator that includes JSONValue +def test_custom_openapi_replaces_jsonvalue_variants(monkeypatch) -> None: + # Given: a custom openapi generator that includes Pydantic JSONValue schemas + json_value_schema_names = ( + "JSONValue", + "JSONValue-Input", + "JSONValue-Output", + "JsonValue", + "JsonValue-Input", + "JsonValue-Output", + ) + def fake_get_openapi(*, title, version, description, routes): return { - "components": {"schemas": {"JSONValue": {"type": "object"}}}, + "components": { + "schemas": {name: {"type": "object"} for name in json_value_schema_names} + }, "info": {"title": title, "version": version, "description": description}, "paths": {}, } @@ -171,8 +183,10 @@ def fake_get_openapi(*, title, version, description, routes): # When: generating openapi schema = main_module.app.openapi() - # Then: JSONValue is replaced with safe description - assert schema["components"]["schemas"]["JSONValue"]["description"] == "Any JSON value" + # Then: JSONValue schemas are replaced with a non-recursive 
schema + schemas = schema["components"]["schemas"] + for schema_name in json_value_schema_names: + assert schemas[schema_name] == {"description": "Any JSON value"} def test_custom_openapi_is_cached(monkeypatch) -> None: From 08f85ca1ffb94015aa20e2211ac7eba9fdc79e69 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 12 May 2026 15:18:35 +0530 Subject: [PATCH 11/20] fix(server): address runtime auth review feedback --- sdks/typescript/src/generated/models/index.ts | 4 - .../src/generated/models/json-value-input.ts | 40 ---------- .../src/generated/models/json-value-input1.ts | 47 ------------ .../src/generated/models/json-value-output.ts | 44 ----------- .../generated/models/json-value-output1.ts | 48 ------------ sdks/typescript/src/generated/models/step.ts | 30 +++----- .../models/template-definition-input.ts | 14 ++-- .../models/template-definition-output.ts | 11 ++- .../auth_framework/config.py | 6 +- .../auth_framework/providers/local_jwt.py | 38 ++++------ .../endpoints/evaluation.py | 4 +- .../agent_control_server/services/controls.py | 9 ++- server/tests/test_auth_framework.py | 37 +++++++++- server/tests/test_principal_namespace_flow.py | 74 +++++++++++++++++++ .../test_runtime_token_exchange_endpoint.py | 43 ++++++++++- 15 files changed, 201 insertions(+), 248 deletions(-) delete mode 100644 sdks/typescript/src/generated/models/json-value-input.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-input1.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-output.ts delete mode 100644 sdks/typescript/src/generated/models/json-value-output1.ts diff --git a/sdks/typescript/src/generated/models/index.ts b/sdks/typescript/src/generated/models/index.ts index a31abbbe..595a9501 100644 --- a/sdks/typescript/src/generated/models/index.ts +++ b/sdks/typescript/src/generated/models/index.ts @@ -63,10 +63,6 @@ export * from "./init-agent-evaluator-removal.js"; export * from "./init-agent-overwrite-changes.js"; export * from 
"./init-agent-request.js"; export * from "./init-agent-response.js"; -export * from "./json-value-input.js"; -export * from "./json-value-input1.js"; -export * from "./json-value-output.js"; -export * from "./json-value-output1.js"; export * from "./list-agents-response.js"; export * from "./list-control-bindings-response.js"; export * from "./list-control-versions-response.js"; diff --git a/sdks/typescript/src/generated/models/json-value-input.ts b/sdks/typescript/src/generated/models/json-value-input.ts deleted file mode 100644 index 4f448073..00000000 --- a/sdks/typescript/src/generated/models/json-value-input.ts +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. - */ - -import * as z from "zod/v4-mini"; -import { smartUnion } from "../types/smart-union.js"; - -export type JSONValueInput = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput | null }; - -/** @internal */ -export type JSONValueInput$Outbound = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput$Outbound | null }; - -/** @internal */ -export const JSONValueInput$outboundSchema: z.ZodMiniType< - JSONValueInput$Outbound, - JSONValueInput -> = smartUnion([ - z.string(), - z.int(), - z.number(), - z.boolean(), - z.array(z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), - z.record(z.string(), z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), -]); - -export function jsonValueInputToJSON(jsonValueInput: JSONValueInput): string { - return JSON.stringify(JSONValueInput$outboundSchema.parse(jsonValueInput)); -} diff --git a/sdks/typescript/src/generated/models/json-value-input1.ts b/sdks/typescript/src/generated/models/json-value-input1.ts deleted file mode 100644 index b613f2e4..00000000 --- a/sdks/typescript/src/generated/models/json-value-input1.ts +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. 
- */ - -import * as z from "zod/v4-mini"; -import { smartUnion } from "../types/smart-union.js"; -import { - JSONValueInput, - JSONValueInput$Outbound, - JSONValueInput$outboundSchema, -} from "./json-value-input.js"; - -export type JsonValueInput1 = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput | null }; - -/** @internal */ -export type JsonValueInput1$Outbound = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueInput$Outbound | null }; - -/** @internal */ -export const JsonValueInput1$outboundSchema: z.ZodMiniType< - JsonValueInput1$Outbound, - JsonValueInput1 -> = smartUnion([ - z.string(), - z.int(), - z.number(), - z.boolean(), - z.array(z.nullable(JSONValueInput$outboundSchema)), - z.record(z.string(), z.nullable(z.lazy(() => JSONValueInput$outboundSchema))), -]); - -export function jsonValueInput1ToJSON( - jsonValueInput1: JsonValueInput1, -): string { - return JSON.stringify(JsonValueInput1$outboundSchema.parse(jsonValueInput1)); -} diff --git a/sdks/typescript/src/generated/models/json-value-output.ts b/sdks/typescript/src/generated/models/json-value-output.ts deleted file mode 100644 index f50e2790..00000000 --- a/sdks/typescript/src/generated/models/json-value-output.ts +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. 
- */ - -import * as z from "zod/v4-mini"; -import { safeParse } from "../lib/schemas.js"; -import { Result as SafeParseResult } from "../types/fp.js"; -import * as types from "../types/primitives.js"; -import { smartUnion } from "../types/smart-union.js"; -import { SDKValidationError } from "./errors/sdk-validation-error.js"; - -export type JSONValueOutput = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueOutput | null }; - -/** @internal */ -export const JSONValueOutput$inboundSchema: z.ZodMiniType< - JSONValueOutput, - unknown -> = smartUnion([ - types.string(), - types.number(), - types.number(), - types.boolean(), - z.array(types.nullable(z.lazy(() => JSONValueOutput$inboundSchema))), - z.record( - z.string(), - types.nullable(z.lazy(() => JSONValueOutput$inboundSchema)), - ), -]); - -export function jsonValueOutputFromJSON( - jsonString: string, -): SafeParseResult { - return safeParse( - jsonString, - (x) => JSONValueOutput$inboundSchema.parse(JSON.parse(x)), - `Failed to parse 'JSONValueOutput' from JSON`, - ); -} diff --git a/sdks/typescript/src/generated/models/json-value-output1.ts b/sdks/typescript/src/generated/models/json-value-output1.ts deleted file mode 100644 index 877520c3..00000000 --- a/sdks/typescript/src/generated/models/json-value-output1.ts +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. 
- */ - -import * as z from "zod/v4-mini"; -import { safeParse } from "../lib/schemas.js"; -import { Result as SafeParseResult } from "../types/fp.js"; -import * as types from "../types/primitives.js"; -import { smartUnion } from "../types/smart-union.js"; -import { SDKValidationError } from "./errors/sdk-validation-error.js"; -import { - JSONValueOutput, - JSONValueOutput$inboundSchema, -} from "./json-value-output.js"; - -export type JsonValueOutput1 = - | string - | number - | number - | boolean - | Array - | { [k: string]: JSONValueOutput | null }; - -/** @internal */ -export const JsonValueOutput1$inboundSchema: z.ZodMiniType< - JsonValueOutput1, - unknown -> = smartUnion([ - types.string(), - types.number(), - types.number(), - types.boolean(), - z.array(types.nullable(JSONValueOutput$inboundSchema)), - z.record( - z.string(), - types.nullable(z.lazy(() => JSONValueOutput$inboundSchema)), - ), -]); - -export function jsonValueOutput1FromJSON( - jsonString: string, -): SafeParseResult { - return safeParse( - jsonString, - (x) => JsonValueOutput1$inboundSchema.parse(JSON.parse(x)), - `Failed to parse 'JsonValueOutput1' from JSON`, - ); -} diff --git a/sdks/typescript/src/generated/models/step.ts b/sdks/typescript/src/generated/models/step.ts index 8c3d4468..132cf9c9 100644 --- a/sdks/typescript/src/generated/models/step.ts +++ b/sdks/typescript/src/generated/models/step.ts @@ -3,11 +3,6 @@ */ import * as z from "zod/v4-mini"; -import { - JSONValueInput, - JSONValueInput$Outbound, - JSONValueInput$outboundSchema, -} from "./json-value-input.js"; /** * Runtime payload for an agent step invocation. @@ -16,8 +11,11 @@ export type Step = { /** * Optional context (conversation history, metadata, etc.) 
*/ - context?: { [k: string]: JSONValueInput | null } | null | undefined; - input: JSONValueInput | null; + context?: { [k: string]: any } | null | undefined; + /** + * Any JSON value + */ + input: any; /** * Step name (tool name or model/chain id) */ @@ -25,7 +23,7 @@ export type Step = { /** * Output content for this step (None for pre-checks) */ - output?: JSONValueInput | null | undefined; + output?: any | null | undefined; /** * Step type (e.g., 'tool', 'llm') */ @@ -34,24 +32,20 @@ export type Step = { /** @internal */ export type Step$Outbound = { - context?: { [k: string]: JSONValueInput$Outbound | null } | null | undefined; - input: JSONValueInput$Outbound | null; + context?: { [k: string]: any } | null | undefined; + input: any; name: string; - output?: JSONValueInput$Outbound | null | undefined; + output?: any | null | undefined; type: string; }; /** @internal */ export const Step$outboundSchema: z.ZodMiniType = z.object( { - context: z.optional( - z.nullable( - z.record(z.string(), z.nullable(JSONValueInput$outboundSchema)), - ), - ), - input: z.nullable(JSONValueInput$outboundSchema), + context: z.optional(z.nullable(z.record(z.string(), z.any()))), + input: z.any(), name: z.string(), - output: z.optional(z.nullable(JSONValueInput$outboundSchema)), + output: z.optional(z.nullable(z.any())), type: z.string(), }, ); diff --git a/sdks/typescript/src/generated/models/template-definition-input.ts b/sdks/typescript/src/generated/models/template-definition-input.ts index 61e40755..e27f379e 100644 --- a/sdks/typescript/src/generated/models/template-definition-input.ts +++ b/sdks/typescript/src/generated/models/template-definition-input.ts @@ -4,11 +4,6 @@ import * as z from "zod/v4-mini"; import { remap as remap$ } from "../lib/primitives.js"; -import { - JsonValueInput1, - JsonValueInput1$Outbound, - JsonValueInput1$outboundSchema, -} from "./json-value-input1.js"; import { TemplateParameterDefinition, TemplateParameterDefinition$Outbound, @@ -19,7 +14,10 @@ 
import { * Reusable template with typed parameters and a JSON definition template. */ export type TemplateDefinitionInput = { - definitionTemplate: JsonValueInput1 | null; + /** + * Any JSON value + */ + definitionTemplate: any; /** * Metadata describing the template itself */ @@ -32,7 +30,7 @@ export type TemplateDefinitionInput = { /** @internal */ export type TemplateDefinitionInput$Outbound = { - definition_template: JsonValueInput1$Outbound | null; + definition_template: any; description?: string | null | undefined; parameters?: | { [k: string]: TemplateParameterDefinition$Outbound } @@ -45,7 +43,7 @@ export const TemplateDefinitionInput$outboundSchema: z.ZodMiniType< TemplateDefinitionInput > = z.pipe( z.object({ - definitionTemplate: z.nullable(JsonValueInput1$outboundSchema), + definitionTemplate: z.any(), description: z.optional(z.nullable(z.string())), parameters: z.optional( z.record(z.string(), TemplateParameterDefinition$outboundSchema), diff --git a/sdks/typescript/src/generated/models/template-definition-output.ts b/sdks/typescript/src/generated/models/template-definition-output.ts index b246dd7d..15cc9140 100644 --- a/sdks/typescript/src/generated/models/template-definition-output.ts +++ b/sdks/typescript/src/generated/models/template-definition-output.ts @@ -8,10 +8,6 @@ import { safeParse } from "../lib/schemas.js"; import { Result as SafeParseResult } from "../types/fp.js"; import * as types from "../types/primitives.js"; import { SDKValidationError } from "./errors/sdk-validation-error.js"; -import { - JsonValueOutput1, - JsonValueOutput1$inboundSchema, -} from "./json-value-output1.js"; import { TemplateParameterDefinition, TemplateParameterDefinition$inboundSchema, @@ -21,7 +17,10 @@ import { * Reusable template with typed parameters and a JSON definition template. 
*/ export type TemplateDefinitionOutput = { - definitionTemplate: JsonValueOutput1 | null; + /** + * Any JSON value + */ + definitionTemplate: any; /** * Metadata describing the template itself */ @@ -38,7 +37,7 @@ export const TemplateDefinitionOutput$inboundSchema: z.ZodMiniType< unknown > = z.pipe( z.object({ - definition_template: types.nullable(JsonValueOutput1$inboundSchema), + definition_template: z.any(), description: z.optional(z.nullable(types.string())), parameters: types.optional( z.record(z.string(), TemplateParameterDefinition$inboundSchema), diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 595c3117..73852248 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -218,7 +218,8 @@ def _build_default_provider() -> RequestAuthorizer: ) ) raise RuntimeError( - f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', 'header', " + "or 'http_upstream'." ) @@ -259,7 +260,8 @@ def _resolve_runtime_mode() -> str: if mode == "jwt": return mode raise RuntimeError( - f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', " + "'header', or 'jwt'." ) diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index 8620d3b6..3f39e6fd 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -4,9 +4,8 @@ ``Authorization`` header, verifies the signature against the runtime secret, checks the token's scope covers the requested operation, and returns a :class:`Principal` carrying the bound target. 
When a -``context_builder`` on the dependency surfaces ``target_type`` / -``target_id``, the provider also enforces that they match the token's -binding - runtime endpoints get the request-target check for free. +target-bound token is presented, the dependency's ``context_builder`` must +surface matching ``target_type`` / ``target_id`` values. """ from __future__ import annotations @@ -55,25 +54,20 @@ async def authorize( hint="Request a token with the required scope.", ) - if context is not None: - requested_target_type = context.get("target_type") - requested_target_id = context.get("target_id") - if requested_target_type is not None and requested_target_type != claims.target_type: - raise ForbiddenError( - error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, - detail=( - "Runtime token target_type does not match the request." - ), - hint="Re-exchange a token bound to the request target.", - ) - if requested_target_id is not None and requested_target_id != claims.target_id: - raise ForbiddenError( - error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, - detail=( - "Runtime token target_id does not match the request."
- ), - hint="Re-exchange a token bound to the request target.", - ) + requested_target_type = context.get("target_type") if context is not None else None + requested_target_id = context.get("target_id") if context is not None else None + if requested_target_type != claims.target_type: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Runtime token target_type does not match the request.", + hint="Re-exchange a token bound to the request target.", + ) + if requested_target_id != claims.target_id: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Runtime token target_id does not match the request.", + hint="Re-exchange a token bound to the request target.", + ) return Principal( namespace_key=claims.namespace_key, diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index 437af8b5..30779c5c 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -1,5 +1,6 @@ """Evaluation analysis endpoints.""" +import json from dataclasses import dataclass from agent_control_engine.core import ControlEngine @@ -121,7 +122,8 @@ async def _evaluation_context(request: Request) -> dict[str, object]: """Surface target identifiers to the runtime authorizer.""" try: body = await request.json() - except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + except (json.JSONDecodeError, UnicodeDecodeError): + _logger.debug("Unable to decode evaluation request body for auth context") return {} if not isinstance(body, dict): return {} diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index e3a5fd26..6c015310 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -134,13 +134,14 @@ async def get_control_or_404( 
self, control_id: int, *, - namespace_key: str | None = None, + namespace_key: str, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" - stmt = select(Control).where(Control.id == control_id) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 20c58aed..06f1be89 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -742,7 +742,11 @@ async def test_local_jwt_provider_returns_target_bound_principal(): provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = _build_request(headers={"Authorization": f"Bearer {token}"}) - principal = await provider.authorize(request, Operation.RUNTIME_USE) + principal = await provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-42"}, + ) assert principal.target_type == "log_stream" assert principal.target_id == "ls-42" @@ -815,10 +819,39 @@ async def test_local_jwt_provider_carries_token_namespace_to_principal(): provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = _build_request(headers={"Authorization": f"Bearer {token}"}) - principal = await provider.authorize(request, Operation.RUNTIME_USE) + principal = await provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls"}, + ) assert principal.namespace_key == "org-7" +@pytest.mark.asyncio +async def test_local_jwt_provider_rejects_missing_target_context(): + """A target-bound runtime token requires matching request target context.""" + from agent_control_server.auth_framework.providers import LocalJwtVerifyProvider + from 
agent_control_server.auth_framework.runtime_token import ( + mint_runtime_token, + ) + from agent_control_server.errors import ForbiddenError + + token, _ = mint_runtime_token( + namespace_key="default", + actor_id="a", + target_type="log_stream", + target_id="bound-target", + scopes=("runtime.use",), + secret=_TEST_SECRET, + ttl_seconds=60, + ) + provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) + request = _build_request(headers={"Authorization": f"Bearer {token}"}) + + with pytest.raises(ForbiddenError, match="target_type does not match"): + await provider.authorize(request, Operation.RUNTIME_USE) + + @pytest.mark.asyncio async def test_local_jwt_provider_enforces_target_context_match(): """When the dependency surfaces a target context, the provider enforces it.""" diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 14d2d874..0ca1bca8 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -129,6 +129,80 @@ def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None assert eval_b.json()["is_safe"] is True +def test_principal_namespace_scopes_cross_namespace_writes(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + assert ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)).status_code == 200 + assert ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)).status_code == 200 + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert 
policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + + binding = ns_a.put( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.patch(f"/api/v1/controls/{control_id}", json={"enabled": False}).status_code == 404 + assert ( + ns_b.put( + f"/api/v1/controls/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ).status_code + == 404 + ) + assert ( + ns_b.put( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + }, + ).status_code + == 404 + ) + delete_binding = ns_b.post( + "/api/v1/control-bindings/by-key:delete", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + }, + ) + assert delete_binding.status_code == 200, delete_binding.text + assert delete_binding.json()["deleted"] is False + assert ns_a.get("/api/v1/control-bindings").json()["bindings"] + + assert ns_b.post(f"/api/v1/agents/{agent_name}/policies/{policy_id}").status_code == 404 + assert ns_b.post(f"/api/v1/agents/{agent_name}/controls/{control_id}").status_code == 404 + + def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: set_authorizer(HeaderNamespaceAuthorizer()) diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 1b1edae2..0863c9a0 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -172,7 +172,11 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ verify_provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) request = MagicMock() request.headers = {"Authorization": f"Bearer {token}"} - principal = await verify_provider.authorize(request, 
Operation.RUNTIME_USE) + principal = await verify_provider.authorize( + request, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-99"}, + ) assert principal.target_type == "log_stream" assert principal.target_id == "ls-99" @@ -212,6 +216,37 @@ def test_evaluation_rejects_runtime_jwt_for_wrong_target( assert response.json()["detail"] == "Runtime token target_id does not match the request." +def test_evaluation_rejects_runtime_jwt_without_bound_target_context( + client: TestClient, + runtime_config_enabled, +): + """A target-bound runtime JWT must not authorize a target-less evaluation.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_type does not match the request." 
+ + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, @@ -316,7 +351,11 @@ async def authorize(self, request, operation, context=None): verify_provider = LocalJwtVerifyProvider(secret=_TEST_SECRET) req = MagicMock() req.headers = {"Authorization": f"Bearer {token}"} - principal = await verify_provider.authorize(req, Operation.RUNTIME_USE) + principal = await verify_provider.authorize( + req, + Operation.RUNTIME_USE, + context={"target_type": "log_stream", "target_id": "ls-org-a"}, + ) assert principal.namespace_key == "org-A" assert principal.target_id == "ls-org-a" From eeaba0b1559425988e255a74fced1554f5daba47 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 12 May 2026 15:48:33 +0530 Subject: [PATCH 12/20] fix(server): route evaluator and observability auth through framework --- docs/auth.md | 3 + .../auth_framework/core.py | 6 +- .../auth_framework/providers/header.py | 3 + .../endpoints/evaluators.py | 5 +- .../endpoints/observability.py | 35 ++++++-- server/src/agent_control_server/main.py | 7 +- server/tests/test_auth.py | 78 ++++++++++++----- server/tests/test_observability_endpoints.py | 83 ++++++++++++++++--- 8 files changed, 170 insertions(+), 50 deletions(-) diff --git a/docs/auth.md b/docs/auth.md index 7aafd2ad..9d2f6efd 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -17,6 +17,9 @@ policies.update agents.read agents.create agents.update +evaluators.read +observability.read +observability.write control_bindings.read control_bindings.write runtime.token_exchange diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 058169de..011c62de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -55,6 +55,9 @@ class Operation(StrEnum): AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" + 
EVALUATORS_READ = "evaluators.read" + OBSERVABILITY_READ = "observability.read" + OBSERVABILITY_WRITE = "observability.write" RUNTIME_USE = "runtime.use" @@ -109,8 +112,7 @@ async def authorize( request: Request, operation: Operation, context: dict[str, Any] | None = None, - ) -> Principal: - ... + ) -> Principal: ... _default_authorizer: RequestAuthorizer | None = None diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 16760768..2d917d91 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -48,6 +48,9 @@ class AccessLevel(Enum): Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.EVALUATORS_READ: AccessLevel.AUTHENTICATED, + Operation.OBSERVABILITY_READ: AccessLevel.AUTHENTICATED, + Operation.OBSERVABILITY_WRITE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/evaluators.py b/server/src/agent_control_server/endpoints/evaluators.py index a9cdaa2a..6bbeddfc 100644 --- a/server/src/agent_control_server/endpoints/evaluators.py +++ b/server/src/agent_control_server/endpoints/evaluators.py @@ -3,9 +3,11 @@ from typing import Any from agent_control_engine import list_evaluators -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, Field +from ..auth_framework import Operation, require_operation + router = APIRouter(prefix="/evaluators", tags=["evaluators"]) @@ -25,6 +27,7 @@ class EvaluatorInfo(BaseModel): response_model=dict[str, EvaluatorInfo], summary="List available evaluators", response_description="Dictionary of evaluator name to evaluator info", + 
dependencies=[Depends(require_operation(Operation.EVALUATORS_READ))], ) async def get_evaluators() -> dict[str, EvaluatorInfo]: """List all available evaluators. diff --git a/server/src/agent_control_server/endpoints/observability.py b/server/src/agent_control_server/endpoints/observability.py index 5de90c0a..3296ca1c 100644 --- a/server/src/agent_control_server/endpoints/observability.py +++ b/server/src/agent_control_server/endpoints/observability.py @@ -5,7 +5,7 @@ 2. Event queries (POST /events/query) - Query raw events by trace_id, etc. 3. Stats (GET /stats) - Aggregated statistics for dashboards -All endpoints require API key authentication. +All endpoints declare operation-based auth dependencies. Dependencies are stored on app.state during server lifespan (see main.py): - app.state.event_ingestor: EventIngestor @@ -27,7 +27,7 @@ ) from fastapi import APIRouter, Depends, Request -from ..auth import require_api_key +from ..auth_framework import Operation, require_operation from ..observability.ingest.base import EventIngestor from ..observability.store.base import ( EventStore, @@ -42,7 +42,6 @@ router = APIRouter( prefix="/observability", tags=["observability"], - dependencies=[Depends(require_api_key)], ) @@ -72,7 +71,12 @@ def get_event_store(request: Request) -> EventStore: # ============================================================================= -@router.post("/events", status_code=202, response_model=BatchEventsResponse) +@router.post( + "/events", + status_code=202, + response_model=BatchEventsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_WRITE))], +) async def ingest_events( request: BatchEventsRequest, ingestor: EventIngestor = Depends(get_event_ingestor), @@ -121,7 +125,11 @@ async def ingest_events( # ============================================================================= -@router.post("/events/query", response_model=EventQueryResponse) +@router.post( + "/events/query", + response_model=EventQueryResponse, 
+ dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def query_events( request: EventQueryRequest, store: EventStore = Depends(get_event_store), @@ -158,7 +166,11 @@ async def query_events( # ============================================================================= -@router.get("/stats", response_model=StatsResponse) +@router.get( + "/stats", + response_model=StatsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_stats( agent_name: str, time_range: TimeRange = "5m", @@ -207,7 +219,11 @@ async def get_stats( ) -@router.get("/stats/controls/{control_id}", response_model=ControlStatsResponse) +@router.get( + "/stats/controls/{control_id}", + response_model=ControlStatsResponse, + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_control_stats( control_id: int, agent_name: str, @@ -266,7 +282,10 @@ async def get_control_stats( # ============================================================================= -@router.get("/status") +@router.get( + "/status", + dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], +) async def get_status(request: Request) -> dict: """ Get observability system status. diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 90763e21..89a6275d 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -17,7 +17,7 @@ from starlette_exporter import PrometheusMiddleware, handle_metrics from . 
import __version__ as server_version -from .auth import get_api_key_from_header, require_api_key +from .auth import get_api_key_from_header from .config import observability_settings, settings from .db import AsyncSessionLocal from .endpoints.agents import router as agent_router @@ -314,17 +314,16 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) -# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) -# Observability routes (already has auth dependency in router) app.include_router( observability_router, prefix=api_v1_prefix, + dependencies=[Depends(get_api_key_from_header)], ) # System routes (config, login, logout) - no auth required diff --git a/server/tests/test_auth.py b/server/tests/test_auth.py index 44f2de27..fba5088c 100644 --- a/server/tests/test_auth.py +++ b/server/tests/test_auth.py @@ -1,16 +1,36 @@ """Tests for API key authentication.""" import uuid +from typing import Any import pytest +from fastapi import Request from fastapi.testclient import TestClient from agent_control_server import __version__ as server_version +from agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.config import auth_settings from .utils import VALID_CONTROL_PAYLOAD +class _RecordingAuthorizer: + """Test authorizer that records the operation requested by a route.""" + + def __init__(self) -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + return Principal(namespace_key="default") + + class TestHealthEndpoint: """Health endpoint should always be accessible 
without authentication.""" @@ -40,9 +60,7 @@ class TestProtectedEndpoints: def test_missing_api_key_returns_401(self, unauthenticated_client: TestClient) -> None: """Given no API key, when requesting protected endpoint, then returns 401.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: assert response.status_code == 401 @@ -111,6 +129,20 @@ def test_missing_key_returns_401_on_evaluators( # Then: assert response.status_code == 401 + def test_evaluators_use_auth_framework_provider(self, app: object) -> None: + """Given a custom authorizer, when listing evaluators, then route uses it.""" + # Given: + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + + # When: + response = client.get("/api/v1/evaluators") + + # Then: + assert response.status_code == 200 + assert authorizer.calls == [(Operation.EVALUATORS_READ, None)] + class TestAuthDisabled: """When auth is disabled, all requests should succeed.""" @@ -120,21 +152,15 @@ def disable_auth(self, monkeypatch: pytest.MonkeyPatch) -> None: """Disable auth for tests in this class.""" monkeypatch.setattr(auth_settings, "api_key_enabled", False) - def test_no_key_allowed_when_disabled( - self, unauthenticated_client: TestClient - ) -> None: + def test_no_key_allowed_when_disabled(self, unauthenticated_client: TestClient) -> None: """Given auth disabled, when requesting without API key, then request succeeds.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: (404 for non-existent resource, but NOT 401) assert response.status_code == 404 - def test_evaluators_accessible_when_disabled( - self, unauthenticated_client: TestClient - ) 
-> None: + def test_evaluators_accessible_when_disabled(self, unauthenticated_client: TestClient) -> None: """Given auth disabled, when listing evaluators without API key, then returns 200.""" # When: response = unauthenticated_client.get("/api/v1/evaluators") @@ -264,9 +290,7 @@ def test_admin_key_allowed_on_representative_mutations(self, admin_client: TestC init_response = admin_client.post("/api/v1/agents/initAgent", json=init_payload) assert init_response.status_code == 200 - set_policy_response = admin_client.post( - f"/api/v1/agents/{agent_name}/policy/{policy_id}" - ) + set_policy_response = admin_client.post(f"/api/v1/agents/{agent_name}/policy/{policy_id}") assert set_policy_response.status_code == 200 @@ -344,9 +368,7 @@ def setup_no_keys(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_misconfigured_returns_500(self, unauthenticated_client: TestClient) -> None: """Given auth enabled but no keys configured, when requesting, then returns 500.""" # When: - response = unauthenticated_client.get( - "/api/v1/agents/00000000-0000-0000-0000-000000000000" - ) + response = unauthenticated_client.get("/api/v1/agents/00000000-0000-0000-0000-000000000000") # Then: assert response.status_code == 500 @@ -360,6 +382,7 @@ class TestOptionalApiKey: def _make_optional_app(self) -> TestClient: from fastapi import Depends, FastAPI + from agent_control_server.auth import optional_api_key app = FastAPI() @@ -374,7 +397,9 @@ def maybe_auth(client=Depends(optional_api_key)) -> dict[str, object]: return TestClient(app) - def test_optional_api_key_auth_disabled_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_auth_disabled_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth disabled monkeypatch.setattr(auth_settings, "api_key_enabled", False) @@ -386,7 +411,9 @@ def test_optional_api_key_auth_disabled_returns_none(self, monkeypatch: pytest.M assert response.status_code == 200 assert 
response.json()["auth"] is False - def test_optional_api_key_missing_header_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_missing_header_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with configured keys monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -402,7 +429,9 @@ def test_optional_api_key_missing_header_returns_none(self, monkeypatch: pytest. assert response.status_code == 200 assert response.json()["auth"] is False - def test_optional_api_key_invalid_header_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_invalid_header_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with configured keys monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -418,7 +447,9 @@ def test_optional_api_key_invalid_header_returns_none(self, monkeypatch: pytest. 
assert response.status_code == 200 assert response.json()["auth"] is False - def test_optional_api_key_admin_header_sets_admin(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_optional_api_key_admin_header_sets_admin( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Given: auth enabled with admin key monkeypatch.setattr(auth_settings, "api_key_enabled", True) monkeypatch.setattr(auth_settings, "api_keys", "user-key") @@ -449,6 +480,7 @@ def test_require_admin_key_rejects_non_admin( # When: requiring admin key on an endpoint from fastapi import Depends, FastAPI + from agent_control_server.auth import require_admin_key local_app = FastAPI() @@ -483,6 +515,7 @@ def test_authenticated_client_key_id_masks_short_key(self) -> None: def test_get_api_key_from_header_extracts_value(self) -> None: # Given: a route that returns raw API key header from fastapi import Depends, FastAPI + from agent_control_server.auth import get_api_key_from_header app = FastAPI() @@ -503,6 +536,7 @@ def raw_key(key: str | None = Depends(get_api_key_from_header)) -> dict[str, str def test_get_api_key_from_header_allows_missing(self) -> None: # Given: a route that returns raw API key header from fastapi import Depends, FastAPI + from agent_control_server.auth import get_api_key_from_header app = FastAPI() diff --git a/server/tests/test_observability_endpoints.py b/server/tests/test_observability_endpoints.py index 476cf00c..97fc0f7c 100644 --- a/server/tests/test_observability_endpoints.py +++ b/server/tests/test_observability_endpoints.py @@ -2,17 +2,20 @@ import json from datetime import datetime, timedelta, timezone +from typing import Any from uuid import UUID, uuid4 import pytest -from fastapi.testclient import TestClient -from sqlalchemy import text - from agent_control_models import ( BatchEventsRequest, ControlExecutionEvent, EventQueryRequest, ) +from fastapi import Request +from fastapi.testclient import TestClient +from sqlalchemy import text + +from 
agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.main import app from agent_control_server.observability.ingest.base import IngestResult @@ -42,6 +45,64 @@ def create_test_event( ) +class _RecordingAuthorizer: + """Test authorizer that records the operation requested by a route.""" + + def __init__(self) -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + return Principal(namespace_key="default") + + +class TestObservabilityAuthFramework: + """Tests observability routes declare operation-based authorization.""" + + def test_status_uses_read_operation(self, app: object) -> None: + """Given a custom authorizer, when getting status, then read is authorized.""" + # Given: + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + + # When: + response = client.get("/api/v1/observability/status") + + # Then: + assert response.status_code == 200 + assert authorizer.calls == [(Operation.OBSERVABILITY_READ, None)] + + def test_ingest_events_uses_write_operation( + self, + app: object, + setup_observability: object, + ) -> None: + """Given a custom authorizer, when ingesting events, then write is authorized.""" + # Given: + _ = setup_observability + authorizer = _RecordingAuthorizer() + set_authorizer(authorizer) + client = TestClient(app, raise_server_exceptions=True) + request = BatchEventsRequest(events=[create_test_event()]) + + # When: + response = client.post( + "/api/v1/observability/events", + json=request.model_dump(mode="json"), + ) + + # Then: + assert response.status_code == 202 + assert authorizer.calls == [(Operation.OBSERVABILITY_WRITE, None)] + + class TestEventIngestion: """Tests for POST /events endpoint.""" @@ -155,7 
+216,7 @@ def test_event_with_all_fields(self): event = ControlExecutionEvent( trace_id="a" * 32, span_id="b" * 16, - agent_name="test-agent", + agent_name="test-agent", control_id=1, control_name="test-control", check_stage="post", @@ -441,9 +502,7 @@ async def test_timeseries_aggregates_events_per_bucket( total_exec = sum(b["execution_count"] for b in buckets_with_events) total_match = sum(b["match_count"] for b in buckets_with_events) total_non_match = sum(b["non_match_count"] for b in buckets_with_events) - total_observe = sum( - b["action_counts"].get("observe", 0) for b in buckets_with_events - ) + total_observe = sum(b["action_counts"].get("observe", 0) for b in buckets_with_events) total_deny = sum(b["action_counts"].get("deny", 0) for b in buckets_with_events) assert total_exec == 3 @@ -453,9 +512,7 @@ async def test_timeseries_aggregates_events_per_bucket( assert total_deny == 1 @pytest.mark.asyncio - async def test_timeseries_empty_buckets_included( - self, client: TestClient, setup_observability - ): + async def test_timeseries_empty_buckets_included(self, client: TestClient, setup_observability): """Empty buckets are included with zero counts.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" @@ -594,9 +651,7 @@ async def test_control_stats_with_timeseries(self, client: TestClient, setup_obs assert data["stats"]["execution_count"] == 2 # Sum timeseries buckets should equal total - total_from_buckets = sum( - b["execution_count"] for b in data["stats"]["timeseries"] - ) + total_from_buckets = sum(b["execution_count"] for b in data["stats"]["timeseries"]) assert total_from_buckets == 2 @pytest.mark.asyncio @@ -797,6 +852,7 @@ class TestObservabilityIngestStatus: def test_ingest_events_partial_status(self, client: TestClient, setup_observability): """Test partial status when some events are dropped.""" + # Given: a stub ingestor that drops some events class StubIngestor: async def ingest(self, events): @@ -826,6 +882,7 @@ async def 
ingest(self, events): def test_ingest_events_failed_status(self, client: TestClient, setup_observability): """Test failed status when all events are dropped.""" + # Given: a stub ingestor that drops all events class StubIngestor: async def ingest(self, events): From e35a8a22650a046d76da47ecb1197a69e4942e81 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 14 May 2026 23:33:44 +0530 Subject: [PATCH 13/20] fix(server): harden auth and namespace edges --- docs/auth.md | 7 +- ...c2d8e9a1_namespace_observability_events.py | 44 +++++++ .../auth_framework/config.py | 29 ++++- .../auth_framework/providers/http_upstream.py | 47 ++++++- .../auth_framework/runtime_token.py | 6 + .../agent_control_server/endpoints/agents.py | 21 ++++ .../agent_control_server/endpoints/auth.py | 7 +- .../endpoints/evaluation.py | 11 +- .../endpoints/observability.py | 19 +-- server/src/agent_control_server/migrate.py | 47 ++++--- server/src/agent_control_server/models.py | 10 +- .../observability/ingest/base.py | 10 +- .../observability/ingest/direct.py | 17 ++- .../observability/sinks.py | 10 +- .../observability/store/base.py | 20 ++- .../observability/store/postgres.py | 41 ++++-- server/tests/test_auth_framework.py | 118 ++++++++++++++++++ server/tests/test_init_agent_conflict_mode.py | 59 +++++++++ server/tests/test_migrate.py | 33 +++++ .../tests/test_observability_direct_ingest.py | 42 +++++-- server/tests/test_observability_endpoints.py | 70 +++++++++-- .../test_observability_store_postgres.py | 47 ++++++- .../test_runtime_token_exchange_endpoint.py | 36 +++++- 23 files changed, 671 insertions(+), 80 deletions(-) create mode 100644 server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py create mode 100644 server/tests/test_migrate.py diff --git a/docs/auth.md b/docs/auth.md index 9d2f6efd..c738360b 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -51,9 +51,11 @@ Management auth is selected by `AGENT_CONTROL_AUTH_MODE`. 
| Mode | Meaning | | --- | --- | | `none` | No credentials required. Intended for local development only. | -| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS`. This is the default. `header` is accepted as a backwards-compatible alias. | +| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS` and/or `AGENT_CONTROL_ADMIN_API_KEYS`. Requires `AGENT_CONTROL_API_KEY_ENABLED=true`. `header` is accepted as a backwards-compatible alias. | | `http_upstream` | POST each management authorization decision to `AGENT_CONTROL_AUTH_UPSTREAM_URL`. | +When `AGENT_CONTROL_AUTH_MODE` is unset, startup selects `api_key` if local API-key validation is enabled and `none` otherwise. + Runtime auth is selected by `AGENT_CONTROL_RUNTIME_AUTH_MODE`. | Mode | Meaning | @@ -68,7 +70,7 @@ Common combinations: | Management | Runtime | Use case | | --- | --- | --- | | `api_key` | unset | Existing standalone deployments. | -| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. | +| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. This does not perform per-target authorization; any valid local API key can exchange for any target in the local namespace. | | `http_upstream` | `jwt` | External identity or authorization service for management, local token verify for high-volume runtime calls. | | `none` | `none` | Single-process local development. Do not use in production. | @@ -125,6 +127,7 @@ Status handling: | `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | | Other statuses or upstream network errors | Fail closed with `503`. | | Malformed `200` principal response | Fail closed with `502`. | +| `200` target grant that conflicts with request context | Fail closed with `403`. 
| ## Runtime JWT Claims diff --git a/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py b/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py new file mode 100644 index 00000000..fe769116 --- /dev/null +++ b/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py @@ -0,0 +1,44 @@ +"""namespace observability events + +Revision ID: b6f4c2d8e9a1 +Revises: a7f3b1e0d9c5 +Create Date: 2026-05-14 12:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b6f4c2d8e9a1" +down_revision = "a7f3b1e0d9c5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "control_execution_events", + sa.Column( + "namespace_key", + sa.String(length=255), + server_default=sa.text("'default'"), + nullable=False, + ), + ) + op.create_index( + "ix_events_namespace_agent_time", + "control_execution_events", + ["namespace_key", "agent_name", sa.literal_column("timestamp DESC")], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index( + "ix_events_namespace_agent_time", + table_name="control_execution_events", + ) + op.drop_column("control_execution_events", "namespace_key") diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 73852248..559ba425 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -28,6 +28,7 @@ import os from dataclasses import dataclass +from ..config import auth_settings from ..logging_utils import get_logger from .core import Operation, RequestAuthorizer, clear_authorizers, set_authorizer from .providers import ( @@ -88,8 +89,10 @@ def configure_auth_from_env() -> None: Default flow: - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. 
- - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. - ``header`` remains accepted as a backwards-compatible alias. + - ``AGENT_CONTROL_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. When the mode + is unset, startup selects ``api_key`` only if local API-key validation is + enabled; otherwise it selects ``none``. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. @@ -190,11 +193,17 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + raw_mode = os.environ.get(_MODE_ENV) + mode = ( + raw_mode + if raw_mode is not None + else ("api_key" if auth_settings.api_key_enabled else "none") + ).strip().lower() if mode in {"none", "no_auth"}: _logger.info("Default auth provider: none") return NoAuthProvider() if mode in {"api_key", "header"}: + _validate_local_api_key_mode() _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": @@ -223,6 +232,20 @@ def _build_default_provider() -> RequestAuthorizer: ) +def _validate_local_api_key_mode() -> None: + """Fail startup when local API-key mode has no local key validator.""" + if not auth_settings.api_key_enabled: + raise RuntimeError( + f"{_MODE_ENV}=api_key requires AGENT_CONTROL_API_KEY_ENABLED=true. " + f"Use {_MODE_ENV}=none for deployments without credential enforcement." + ) + if not auth_settings.get_api_keys() and not auth_settings.get_admin_api_keys(): + raise RuntimeError( + f"{_MODE_ENV}=api_key requires AGENT_CONTROL_API_KEYS or " + "AGENT_CONTROL_ADMIN_API_KEYS to be configured." + ) + + def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: """Parse a comma-separated header list into a deduplicated tuple. 
diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 78ed9ae2..27c776bd 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -147,6 +147,18 @@ class HttpUpstreamConfig: dropped. Names duplicating the default set or each other (after case-folding) are deduplicated.""" + def __post_init__(self) -> None: + if self.service_token is None: + return + forwarded = { + name.lower() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self.extra_forward_headers) + } + if self.service_token_header.lower() in forwarded: + raise ValueError( + "service_token_header must not match a forwarded caller credential header" + ) + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -197,7 +209,7 @@ async def authorize( hint="Retry the request; if the failure persists, contact the operator.", ) from exc - return self._handle_response(response, operation) + return self._handle_response(response, operation, context) def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} @@ -215,11 +227,16 @@ def _forward_headers(self, request: Request) -> dict[str, str]: return headers def _handle_response( - self, response: httpx.Response, operation: Operation + self, + response: httpx.Response, + operation: Operation, + context: dict[str, Any] | None, ) -> Principal: status = response.status_code if status == 200: - return self._parse_principal(response) + principal = self._parse_principal(response) + _ensure_target_context_matches_grant(context, principal) + return principal if status == 401: raise AuthenticationError( error_code=ErrorCode.AUTH_INVALID_KEY, @@ -309,3 +326,27 @@ def _parse_principal(self, response: httpx.Response) -> Principal: scopes=grant.scopes, 
grant_expires_at=grant.expires_at, ) + + +def _ensure_target_context_matches_grant( + context: dict[str, Any] | None, + principal: Principal, +) -> None: + """Reject target-bound grants that do not match the requested target.""" + if principal.target_type is None and principal.target_id is None: + return + if context is None: + return + + expected_type = context.get("target_type") + expected_id = context.get("target_id") + if not isinstance(expected_type, str) or not isinstance(expected_id, str): + return + if principal.target_type == expected_type and principal.target_id == expected_id: + return + + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Authorization grant target does not match the requested target.", + hint="Retry with credentials authorized for the requested target.", + ) diff --git a/server/src/agent_control_server/auth_framework/runtime_token.py b/server/src/agent_control_server/auth_framework/runtime_token.py index a8eaa4e4..54c59fbb 100644 --- a/server/src/agent_control_server/auth_framework/runtime_token.py +++ b/server/src/agent_control_server/auth_framework/runtime_token.py @@ -92,6 +92,12 @@ def mint_runtime_token( ) if not namespace_key: raise RuntimeTokenError("namespace_key is required to mint a runtime token") + if not actor_id: + raise RuntimeTokenError("actor_id is required to mint a runtime token") + if not target_type: + raise RuntimeTokenError("target_type is required to mint a runtime token") + if not target_id: + raise RuntimeTokenError("target_id is required to mint a runtime token") if ttl_seconds <= 0: raise RuntimeTokenError("ttl_seconds must be positive") if upstream_expires_at is not None and ( diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 1b380026..d29fbdfc 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -174,6 +174,23 @@ def 
_ensure_target_principal_matches_namespace( ) +async def _authorize_existing_agent_overwrite( + request: Request, + principal: Principal, +) -> None: + update_principal = await get_authorizer(Operation.AGENTS_UPDATE).authorize( + request, + Operation.AGENTS_UPDATE, + ) + if update_principal.namespace_key == principal.namespace_key: + return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Update authorization resolved to a different namespace.", + hint="Ensure the credential is scoped to the requested agent namespace.", + ) + + # ============================================================================= # List Agents Models # ============================================================================= @@ -532,6 +549,7 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, + http_request: Request, db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), target_principal: Principal | None = Depends(_init_agent_target_principal), @@ -664,6 +682,9 @@ async def init_agent( ) return InitAgentResponse(created=created, controls=controls) + if request.force_replace or request.conflict_mode == ConflictMode.OVERWRITE: + await _authorize_existing_agent_overwrite(http_request, principal) + # Parse existing data via AgentData Pydantic model try: data_model = AgentData.model_validate(existing.data) diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 7125b64d..2d242ced 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -13,6 +13,7 @@ from __future__ import annotations +import hashlib from datetime import datetime from typing import Any @@ -34,6 +35,10 @@ _logger = get_logger(__name__) +def _log_hash(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + + class 
RuntimeTokenExchangeRequest(BaseModel): """Body for the runtime token exchange endpoint.""" @@ -181,7 +186,7 @@ async def runtime_token_exchange( "Runtime token exchanged", extra={ "namespace_key": claims.namespace_key, - "actor_id": claims.actor_id, + "actor_id_hash": _log_hash(claims.actor_id), "target_type": claims.target_type, "target_id": claims.target_id, "scopes": list(claims.scopes), diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index 30779c5c..bc66381f 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -127,10 +127,13 @@ async def _evaluation_context(request: Request) -> dict[str, object]: return {} if not isinstance(body, dict): return {} - return { - "target_type": body.get("target_type"), - "target_id": body.get("target_id"), - } + target_type = body.get("target_type") + target_id = body.get("target_id") + if not isinstance(target_type, str) or not isinstance(target_id, str): + return {} + if not target_type or not target_id: + return {} + return {"target_type": target_type, "target_id": target_id} @router.post( diff --git a/server/src/agent_control_server/endpoints/observability.py b/server/src/agent_control_server/endpoints/observability.py index 3296ca1c..4e52377b 100644 --- a/server/src/agent_control_server/endpoints/observability.py +++ b/server/src/agent_control_server/endpoints/observability.py @@ -27,7 +27,7 @@ ) from fastapi import APIRouter, Depends, Request -from ..auth_framework import Operation, require_operation +from ..auth_framework import Operation, Principal, require_operation from ..observability.ingest.base import EventIngestor from ..observability.store.base import ( EventStore, @@ -75,11 +75,11 @@ def get_event_store(request: Request) -> EventStore: "/events", status_code=202, response_model=BatchEventsResponse, - 
dependencies=[Depends(require_operation(Operation.OBSERVABILITY_WRITE))], ) async def ingest_events( request: BatchEventsRequest, ingestor: EventIngestor = Depends(get_event_ingestor), + principal: Principal = Depends(require_operation(Operation.OBSERVABILITY_WRITE)), ) -> BatchEventsResponse: """ Ingest batched control execution events. @@ -95,7 +95,10 @@ async def ingest_events( """ start_time = time.perf_counter() - result = await ingestor.ingest(request.events) + result = await ingestor.ingest( + request.events, + namespace_key=principal.namespace_key, + ) duration_ms = (time.perf_counter() - start_time) * 1000 logger.debug( @@ -128,11 +131,11 @@ async def ingest_events( @router.post( "/events/query", response_model=EventQueryResponse, - dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], ) async def query_events( request: EventQueryRequest, store: EventStore = Depends(get_event_store), + principal: Principal = Depends(require_operation(Operation.OBSERVABILITY_READ)), ) -> EventQueryResponse: """ Query raw control execution events. @@ -158,7 +161,7 @@ async def query_events( Returns: EventQueryResponse with matching events and pagination info """ - return await store.query_events(request) + return await store.query_events(request, namespace_key=principal.namespace_key) # ============================================================================= @@ -169,13 +172,13 @@ async def query_events( @router.get( "/stats", response_model=StatsResponse, - dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], ) async def get_stats( agent_name: str, time_range: TimeRange = "5m", include_timeseries: bool = False, store: EventStore = Depends(get_event_store), + principal: Principal = Depends(require_operation(Operation.OBSERVABILITY_READ)), ) -> StatsResponse: """ Get agent-level aggregated statistics. 
@@ -202,6 +205,7 @@ async def get_stats( control_id=None, include_timeseries=include_timeseries, bucket_size=bucket_size, + namespace_key=principal.namespace_key, ) return StatsResponse( @@ -222,7 +226,6 @@ async def get_stats( @router.get( "/stats/controls/{control_id}", response_model=ControlStatsResponse, - dependencies=[Depends(require_operation(Operation.OBSERVABILITY_READ))], ) async def get_control_stats( control_id: int, @@ -230,6 +233,7 @@ async def get_control_stats( time_range: TimeRange = "5m", include_timeseries: bool = False, store: EventStore = Depends(get_event_store), + principal: Principal = Depends(require_operation(Operation.OBSERVABILITY_READ)), ) -> ControlStatsResponse: """ Get statistics for a single control. @@ -256,6 +260,7 @@ async def get_control_stats( control_id=control_id, include_timeseries=include_timeseries, bucket_size=bucket_size, + namespace_key=principal.namespace_key, ) # Get control name from the stats (should be exactly one) diff --git a/server/src/agent_control_server/migrate.py b/server/src/agent_control_server/migrate.py index 16483bb7..1528305f 100644 --- a/server/src/agent_control_server/migrate.py +++ b/server/src/agent_control_server/migrate.py @@ -9,16 +9,21 @@ import argparse import logging +import shutil import sys +import tempfile +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path -from alembic import command from alembic.config import Config import agent_control_server +from alembic import command -def _bundled_config() -> Config: +@contextmanager +def _bundled_config() -> Iterator[Config]: pkg_dir = Path(agent_control_server.__file__).parent ini_path = pkg_dir / "_alembic.ini" alembic_dir = pkg_dir / "_alembic" @@ -28,9 +33,15 @@ def _bundled_config() -> Config: f"{ini_path} and {alembic_dir}. The installed wheel is missing " "migration assets." 
) - cfg = Config(str(ini_path)) - cfg.set_main_option("script_location", str(alembic_dir).replace("%", "%%")) - return cfg + with tempfile.TemporaryDirectory(prefix="agent-control-alembic-") as tmp: + script_location = Path(tmp) / "_alembic" + shutil.copytree(alembic_dir, script_location) + for injected_init in (script_location / "versions").rglob("__init__.py"): + injected_init.unlink() + + cfg = Config(str(ini_path)) + cfg.set_main_option("script_location", str(script_location).replace("%", "%%")) + yield cfg def _build_parser() -> argparse.ArgumentParser: @@ -78,19 +89,19 @@ def main(argv: list[str] | None = None) -> int: _configure_logging() try: - cfg = _bundled_config() - if parsed.command == "upgrade": - command.upgrade(cfg, parsed.revision, sql=parsed.sql) - elif parsed.command == "downgrade": - command.downgrade(cfg, parsed.revision, sql=parsed.sql) - elif parsed.command == "current": - command.current(cfg) - elif parsed.command == "history": - command.history(cfg) - elif parsed.command == "heads": - command.heads(cfg) - else: # pragma: no cover - argparse guarantees this cannot happen. - parser.error("missing command") + with _bundled_config() as cfg: + if parsed.command == "upgrade": + command.upgrade(cfg, parsed.revision, sql=parsed.sql) + elif parsed.command == "downgrade": + command.downgrade(cfg, parsed.revision, sql=parsed.sql) + elif parsed.command == "current": + command.current(cfg) + elif parsed.command == "history": + command.history(cfg) + elif parsed.command == "heads": + command.heads(cfg) + else: # pragma: no cover - argparse guarantees this cannot happen. 
+ parser.error("missing command") except Exception as exc: print(f"agent-control-migrate: {exc}", file=sys.stderr) return 1 diff --git a/server/src/agent_control_server/models.py b/server/src/agent_control_server/models.py index a218ecd6..a0dfc0ed 100644 --- a/server/src/agent_control_server/models.py +++ b/server/src/agent_control_server/models.py @@ -341,12 +341,12 @@ class ControlExecutionEventDB(Base): Raw control execution events with minimal indexed columns + JSONB. Schema designed for simplicity and flexibility: - - Only 4 columns: control_execution_id, timestamp, agent_name, data + - Indexed columns: namespace_key, control_execution_id, timestamp, agent_name - Full event stored in JSONB 'data' column - Query-time aggregation from JSONB fields - No migrations needed for new event fields - Primary access pattern: (agent_name, timestamp DESC) for stats queries. + Primary access pattern: (namespace_key, agent_name, timestamp DESC) for stats queries. Expression index on (data->>'control_id') for grouping. 
""" @@ -358,6 +358,11 @@ class ControlExecutionEventDB(Base): ) # Minimal indexed columns for efficient queries + namespace_key: Mapped[str] = mapped_column( + String(255), + nullable=False, + server_default=_NAMESPACE_SERVER_DEFAULT, + ) timestamp: Mapped[dt.datetime] = mapped_column( DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), @@ -372,6 +377,7 @@ class ControlExecutionEventDB(Base): # Composite index for agent + time queries (primary access pattern) __table_args__ = ( + Index("ix_events_namespace_agent_time", "namespace_key", "agent_name", timestamp.desc()), Index("ix_events_agent_time", "agent_name", timestamp.desc()), Index("ix_events_data_control_id", text("(data ->> 'control_id'::text)")), ) diff --git a/server/src/agent_control_server/observability/ingest/base.py b/server/src/agent_control_server/observability/ingest/base.py index 6f278893..cc06c3b5 100644 --- a/server/src/agent_control_server/observability/ingest/base.py +++ b/server/src/agent_control_server/observability/ingest/base.py @@ -10,6 +10,8 @@ from agent_control_models.observability import ControlExecutionEvent from pydantic import BaseModel, Field +from ...models import DEFAULT_NAMESPACE_KEY + class IngestResult(BaseModel): """Result of an event ingestion operation. @@ -40,11 +42,17 @@ class EventIngestor(Protocol): - KafkaEventIngestor: Pushes to Kafka topic """ - async def ingest(self, events: list[ControlExecutionEvent]) -> IngestResult: + async def ingest( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> IngestResult: """Ingest events. Returns counts of received/processed/dropped. 
Args: events: List of control execution events to ingest + namespace_key: Namespace that owns the events Returns: IngestResult with counts of received, processed, and dropped events diff --git a/server/src/agent_control_server/observability/ingest/direct.py b/server/src/agent_control_server/observability/ingest/direct.py index 37f7d3e8..3b63afa8 100644 --- a/server/src/agent_control_server/observability/ingest/direct.py +++ b/server/src/agent_control_server/observability/ingest/direct.py @@ -15,6 +15,7 @@ from agent_control_models.observability import ControlExecutionEvent from agent_control_telemetry.sinks import AsyncControlEventSink +from ...models import DEFAULT_NAMESPACE_KEY from ..sinks import EventStoreControlEventSink from ..store.base import EventStore from .base import EventIngestor, IngestResult @@ -53,11 +54,17 @@ def __init__( self.sink = store self.log_to_stdout = log_to_stdout - async def ingest(self, events: list[ControlExecutionEvent]) -> IngestResult: + async def ingest( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> IngestResult: """Ingest events by writing them directly to the configured sink. 
Args: events: List of control execution events to ingest + namespace_key: Namespace that owns the events Returns: IngestResult with counts of received, processed, and dropped events @@ -70,7 +77,13 @@ async def ingest(self, events: list[ControlExecutionEvent]) -> IngestResult: dropped = 0 try: - sink_result = await self.sink.write_events(events) + if isinstance(self.sink, EventStoreControlEventSink): + sink_result = await self.sink.write_events( + events, + namespace_key=namespace_key, + ) + else: + sink_result = await self.sink.write_events(events) processed = sink_result.accepted dropped = sink_result.dropped diff --git a/server/src/agent_control_server/observability/sinks.py b/server/src/agent_control_server/observability/sinks.py index 321979b0..11273ce8 100644 --- a/server/src/agent_control_server/observability/sinks.py +++ b/server/src/agent_control_server/observability/sinks.py @@ -17,6 +17,7 @@ ) from agent_control_telemetry.sink_selection import SinkSelectionError +from ..models import DEFAULT_NAMESPACE_KEY from .store.base import EventStore _named_event_sink_factories: ControlEventSinkFactoryRegistry[ResolvedControlEventBackend] = ( @@ -30,9 +31,14 @@ class EventStoreControlEventSink: def __init__(self, store: EventStore): self.store = store - async def write_events(self, events: Sequence[ControlExecutionEvent]) -> SinkResult: + async def write_events( + self, + events: Sequence[ControlExecutionEvent], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> SinkResult: """Write events to the underlying store and report accepted/dropped counts.""" - stored = await self.store.store(list(events)) + stored = await self.store.store(list(events), namespace_key=namespace_key) dropped = max(len(events) - stored, 0) return SinkResult(accepted=stored, dropped=dropped) diff --git a/server/src/agent_control_server/observability/store/base.py b/server/src/agent_control_server/observability/store/base.py index 2c78f983..4249fcdc 100644 --- 
a/server/src/agent_control_server/observability/store/base.py +++ b/server/src/agent_control_server/observability/store/base.py @@ -25,6 +25,8 @@ ) from pydantic import BaseModel, Field +from ...models import DEFAULT_NAMESPACE_KEY + # Type alias for time range literals TimeRange = Literal["1m", "5m", "15m", "1h", "24h", "7d", "30d", "180d", "365d"] @@ -119,11 +121,17 @@ class EventStore(ABC): """ @abstractmethod - async def store(self, events: list[ControlExecutionEvent]) -> int: + async def store( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> int: """Store raw events. Args: events: List of control execution events to store + namespace_key: Namespace that owns the stored events Returns: Number of events successfully stored @@ -138,6 +146,7 @@ async def query_stats( control_id: int | None = None, include_timeseries: bool = False, bucket_size: timedelta | None = None, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> StatsResult: """Query stats (aggregated at query time from raw events). @@ -147,6 +156,7 @@ async def query_stats( control_id: Optional control ID to filter by include_timeseries: Whether to include time-series data bucket_size: Bucket size for time-series (required if include_timeseries=True) + namespace_key: Namespace whose events should be queried Returns: StatsResult with per-control and total statistics @@ -154,11 +164,17 @@ async def query_stats( pass @abstractmethod - async def query_events(self, query: EventQuery) -> EventQueryResult: + async def query_events( + self, + query: EventQuery, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> EventQueryResult: """Query raw events with filters and pagination. 
Args: query: Query parameters (filters, pagination) + namespace_key: Namespace whose events should be queried Returns: EventQueryResult with matching events and pagination info diff --git a/server/src/agent_control_server/observability/store/postgres.py b/server/src/agent_control_server/observability/store/postgres.py index aff1931e..39972f04 100644 --- a/server/src/agent_control_server/observability/store/postgres.py +++ b/server/src/agent_control_server/observability/store/postgres.py @@ -24,6 +24,7 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from ...models import DEFAULT_NAMESPACE_KEY from .base import EventStore, StatsResult logger = logging.getLogger(__name__) @@ -106,11 +107,17 @@ def __init__(self, session_maker: async_sessionmaker[AsyncSession]): """ self.session_maker = session_maker - async def store(self, events: list[ControlExecutionEvent]) -> int: + async def store( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> int: """Store raw events in PostgreSQL. Uses batch insert with ON CONFLICT DO NOTHING for idempotency. - The simplified schema stores only 4 columns: + The simplified schema stores only a few indexed columns: + - namespace_key (tenant partition) - control_execution_id (PK) - timestamp (indexed) - agent_name (indexed) @@ -118,6 +125,7 @@ async def store(self, events: list[ControlExecutionEvent]) -> int: Args: events: List of control execution events to store + namespace_key: Namespace that owns the events Returns: Number of events successfully stored @@ -125,13 +133,14 @@ async def store(self, events: list[ControlExecutionEvent]) -> int: if not events: return 0 - # Build values for batch insert (only 4 columns) + # Build values for batch insert. 
values = [] for event in events: # Serialize the full event to JSONB event_data = event.model_dump(mode="json") values.append({ + "namespace_key": namespace_key, "control_execution_id": event.control_execution_id, "timestamp": event.timestamp, "agent_name": event.agent_name, @@ -139,13 +148,13 @@ async def store(self, events: list[ControlExecutionEvent]) -> int: }) async with self.session_maker() as session: - # Batch insert with minimal columns + # Batch insert with minimal indexed columns plus JSONB event data. await session.execute( text(""" INSERT INTO control_execution_events ( - control_execution_id, timestamp, agent_name, data + namespace_key, control_execution_id, timestamp, agent_name, data ) VALUES ( - :control_execution_id, :timestamp, :agent_name, + :namespace_key, :control_execution_id, :timestamp, :agent_name, CAST(:data AS JSONB) ) ON CONFLICT (control_execution_id) DO NOTHING @@ -164,6 +173,7 @@ async def query_stats( control_id: int | None = None, include_timeseries: bool = False, bucket_size: timedelta | None = None, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> StatsResult: """Query stats aggregated at query time from raw events. 
@@ -179,6 +189,7 @@ async def query_stats( control_id: Optional control ID to filter by include_timeseries: Whether to include time-series data bucket_size: Bucket size for time-series (required if include_timeseries=True) + namespace_key: Namespace whose events should be queried Returns: StatsResult with per-control and total statistics @@ -187,6 +198,7 @@ async def query_stats( cutoff = now - time_range params: dict = { + "namespace_key": namespace_key, "agent_name": agent_name, "cutoff": cutoff, } @@ -208,7 +220,8 @@ async def query_stats( WITH filtered_events AS ( SELECT timestamp, data FROM control_execution_events - WHERE agent_name = :agent_name + WHERE namespace_key = :namespace_key + AND agent_name = :agent_name AND timestamp >= :cutoff {control_filter} ), @@ -277,7 +290,8 @@ async def query_stats( NULL::timestamptz as bucket, {SQL_STATS_AGGREGATIONS} FROM control_execution_events - WHERE agent_name = :agent_name + WHERE namespace_key = :namespace_key + AND agent_name = :agent_name AND timestamp >= :cutoff {control_filter} GROUP BY data->>'control_id', data->>'control_name' @@ -380,7 +394,12 @@ def _timedelta_to_interval(self, td: timedelta) -> str: else: return f"{total_seconds} seconds" - async def query_events(self, query: EventQueryRequest) -> EventQueryResponse: + async def query_events( + self, + query: EventQueryRequest, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + ) -> EventQueryResponse: """Query raw events with filters and pagination. 
Supports filtering by trace_id, span_id, agent_name, control_ids, @@ -396,8 +415,8 @@ async def query_events(self, query: EventQueryRequest) -> EventQueryResponse: EventQueryResponse with matching events and pagination info """ # Build WHERE clauses and params - where_clauses = [] - params: dict = {} + where_clauses = ["namespace_key = :namespace_key"] + params: dict = {"namespace_key": namespace_key} # Indexed columns (use direct comparison) if query.control_execution_id: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 06f1be89..83276744 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,6 +7,7 @@ import httpx import pytest + from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -28,6 +29,7 @@ from agent_control_server.auth_framework.providers.http_upstream import ( HttpUpstreamConfig, ) +from agent_control_server.config import auth_settings from agent_control_server.errors import ( APIError, AuthenticationError, @@ -49,6 +51,16 @@ def _build_request( return request +def _clear_auth_settings_cache() -> None: + for attr in ( + "_parsed_api_keys", + "_parsed_admin_api_keys", + "_all_valid_keys", + "_all_admin_keys", + ): + auth_settings.__dict__.pop(attr, None) + + # 32-byte test secret (HS256 wants >= 32 bytes; shorter raises a warning). 
_TEST_SECRET = "test-runtime-secret-12345678901234567890" _OTHER_SECRET = "other-runtime-secret-1234567890123456789" @@ -260,6 +272,25 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +def test_http_upstream_rejects_service_token_header_collision(): + with pytest.raises(ValueError, match="service_token_header"): + HttpUpstreamConfig( + url="https://upstream.example/check", + service_token="shh", + service_token_header="Authorization", + ) + + +def test_http_upstream_rejects_extra_forwarded_service_token_header_collision(): + with pytest.raises(ValueError, match="service_token_header"): + HttpUpstreamConfig( + url="https://upstream.example/check", + service_token="shh", + service_token_header="X-Custom-Auth", + extra_forward_headers=("x-custom-auth",), + ) + + @pytest.mark.asyncio async def test_http_upstream_forwards_extra_headers(): # Given: a provider configured with an extra header in its forward list @@ -696,9 +727,39 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): ) +@pytest.mark.parametrize( + "kwargs, message", + [ + ({"actor_id": ""}, "actor_id is required"), + ({"target_type": ""}, "target_type is required"), + ({"target_id": ""}, "target_id is required"), + ], +) +def test_runtime_token_rejects_empty_required_claims(kwargs, message): + from agent_control_server.auth_framework.runtime_token import ( + RuntimeTokenError, + mint_runtime_token, + ) + + token_kwargs = { + "namespace_key": "default", + "actor_id": "actor", + "target_type": "target", + "target_id": "target-id", + "scopes": ("runtime.use",), + "secret": _TEST_SECRET, + "ttl_seconds": 60, + } + token_kwargs.update(kwargs) + + with pytest.raises(RuntimeTokenError, match=message): + mint_runtime_token(**token_kwargs) + + def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt + from 
agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1025,6 +1086,27 @@ async def test_http_upstream_accepts_iso_datetime_and_array_scopes(): assert principal.grant_expires_at.isoformat() == iso_expiry +@pytest.mark.asyncio +async def test_http_upstream_rejects_target_grant_mismatch(): + provider = _build_upstream( + lambda req: httpx.Response( + 200, + json={ + "namespace_key": "org-1", + "target_type": "log_stream", + "target_id": "different", + }, + ) + ) + + with pytest.raises(ForbiddenError, match="does not match"): + await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + context={"target_type": "log_stream", "target_id": "requested"}, + ) + + # --------------------------------------------------------------------------- # configure_auth_from_env / teardown_auth lifecycle # --------------------------------------------------------------------------- @@ -1084,6 +1166,42 @@ def test_build_default_provider_accepts_none_mode(monkeypatch): assert isinstance(auth_config._build_default_provider(), NoAuthProvider) +def test_build_default_provider_defaults_to_none_when_api_keys_disabled(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_AUTH_MODE", raising=False) + monkeypatch.setattr(auth_settings, "api_key_enabled", False) + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def test_build_default_provider_rejects_explicit_api_key_without_validator( + monkeypatch, +): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "api_key") + monkeypatch.setattr(auth_settings, "api_key_enabled", False) + + with pytest.raises(RuntimeError, match="AGENT_CONTROL_API_KEY_ENABLED=true"): + auth_config._build_default_provider() + + +def test_build_default_provider_rejects_explicit_api_key_without_keys( + monkeypatch, +): + from 
agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "api_key") + monkeypatch.setattr(auth_settings, "api_key_enabled", True) + monkeypatch.setattr(auth_settings, "api_keys", "") + monkeypatch.setattr(auth_settings, "admin_api_keys", "") + _clear_auth_settings_cache() + + with pytest.raises(RuntimeError, match="AGENT_CONTROL_API_KEYS"): + auth_config._build_default_provider() + + def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_init_agent_conflict_mode.py b/server/tests/test_init_agent_conflict_mode.py index 8a9e2ce1..2e8b9b80 100644 --- a/server/tests/test_init_agent_conflict_mode.py +++ b/server/tests/test_init_agent_conflict_mode.py @@ -6,11 +6,32 @@ from copy import deepcopy from typing import Any +from agent_control_models.errors import ErrorCode +from fastapi import Request from fastapi.testclient import TestClient +from agent_control_server.auth_framework import Operation, Principal, set_authorizer +from agent_control_server.errors import ForbiddenError + from .utils import VALID_CONTROL_PAYLOAD +class CreateOnlyAuthorizer: + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + if operation is Operation.AGENTS_UPDATE: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="update denied", + ) + return Principal(namespace_key="default", is_admin=True) + + def _init_payload( *, agent_name: str, @@ -153,6 +174,44 @@ def test_init_agent_overwrite_replaces_steps_and_evaluators(client: TestClient) assert {evaluator["name"] for evaluator in get_data["evaluators"]} == {"eval-a", "eval-c"} +def test_init_agent_overwrite_existing_agent_requires_update_auth( + client: TestClient, +) -> None: + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + 
create_resp = client.post( + "/api/v1/agents/initAgent", + json=_init_payload(agent_name=agent_name), + ) + assert create_resp.status_code == 200 + + set_authorizer(CreateOnlyAuthorizer()) + overwrite_resp = client.post( + "/api/v1/agents/initAgent", + json=_init_payload(agent_name=agent_name, conflict_mode="overwrite"), + ) + + assert overwrite_resp.status_code == 403 + + +def test_init_agent_force_replace_existing_agent_requires_update_auth( + client: TestClient, +) -> None: + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + create_resp = client.post( + "/api/v1/agents/initAgent", + json=_init_payload(agent_name=agent_name), + ) + assert create_resp.status_code == 200 + + set_authorizer(CreateOnlyAuthorizer()) + force_resp = client.post( + "/api/v1/agents/initAgent", + json={**_init_payload(agent_name=agent_name), "force_replace": True}, + ) + + assert force_resp.status_code == 403 + + def test_init_agent_overwrite_warns_on_removed_referenced_evaluator(client: TestClient) -> None: # Given: an agent whose assigned policy contains a control referencing an agent evaluator. 
agent_name = f"agent-{uuid.uuid4().hex[:12]}" diff --git a/server/tests/test_migrate.py b/server/tests/test_migrate.py new file mode 100644 index 00000000..9024d909 --- /dev/null +++ b/server/tests/test_migrate.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from pathlib import Path + +import agent_control_server +from agent_control_server import migrate + + +def test_bundled_config_omits_injected_version_init( + tmp_path: Path, + monkeypatch, +) -> None: + package_dir = tmp_path / "agent_control_server" + versions_dir = package_dir / "_alembic" / "versions" + versions_dir.mkdir(parents=True) + (package_dir / "__init__.py").write_text("", encoding="utf-8") + (package_dir / "_alembic.ini").write_text( + "[alembic]\nscript_location = _alembic\n", + encoding="utf-8", + ) + (package_dir / "_alembic" / "env.py").write_text("", encoding="utf-8") + (versions_dir / "__init__.py").write_text("", encoding="utf-8") + (versions_dir / "abc123_example.py").write_text("revision = 'abc123'\n", encoding="utf-8") + + monkeypatch.setattr(agent_control_server, "__file__", str(package_dir / "__init__.py")) + + with migrate._bundled_config() as cfg: + script_location = Path(cfg.get_main_option("script_location")) + assert script_location.exists() + assert (script_location / "versions" / "abc123_example.py").exists() + assert not (script_location / "versions" / "__init__.py").exists() + + assert not script_location.exists() diff --git a/server/tests/test_observability_direct_ingest.py b/server/tests/test_observability_direct_ingest.py index f3f6db81..13e7270c 100644 --- a/server/tests/test_observability_direct_ingest.py +++ b/server/tests/test_observability_direct_ingest.py @@ -3,23 +3,34 @@ import logging import pytest - -from uuid import uuid4 - from agent_control_models.observability import ControlExecutionEvent from agent_control_telemetry.sinks import SinkResult + from agent_control_server.observability.ingest.direct import DirectEventIngestor from 
agent_control_server.observability.store.base import EventStore class FailingStore(EventStore): - async def store(self, events: list[ControlExecutionEvent]) -> int: + async def store( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = "default", + ) -> int: raise RuntimeError("boom") - async def query_stats(self, agent_name, time_range, control_id=None): # pragma: no cover - not used + async def query_stats( + self, + agent_name, + time_range, + control_id=None, + include_timeseries=False, + bucket_size=None, + namespace_key="default", + ): # pragma: no cover - not used raise NotImplementedError - async def query_events(self, query): # pragma: no cover - not used + async def query_events(self, query, *, namespace_key="default"): # pragma: no cover - not used raise NotImplementedError @@ -27,14 +38,27 @@ class CountingStore(EventStore): def __init__(self) -> None: self.calls: list[list[ControlExecutionEvent]] = [] - async def store(self, events: list[ControlExecutionEvent]) -> int: + async def store( + self, + events: list[ControlExecutionEvent], + *, + namespace_key: str = "default", + ) -> int: self.calls.append(events) return len(events) - async def query_stats(self, agent_name, time_range, control_id=None): # pragma: no cover - not used + async def query_stats( + self, + agent_name, + time_range, + control_id=None, + include_timeseries=False, + bucket_size=None, + namespace_key="default", + ): # pragma: no cover - not used raise NotImplementedError - async def query_events(self, query): # pragma: no cover - not used + async def query_events(self, query, *, namespace_key="default"): # pragma: no cover - not used raise NotImplementedError diff --git a/server/tests/test_observability_endpoints.py b/server/tests/test_observability_endpoints.py index 97fc0f7c..114713c7 100644 --- a/server/tests/test_observability_endpoints.py +++ b/server/tests/test_observability_endpoints.py @@ -1,7 +1,7 @@ """Tests for observability API endpoints.""" import 
json -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any from uuid import UUID, uuid4 @@ -40,7 +40,7 @@ def create_test_event( action=action, matched=matched, confidence=0.95, - timestamp=timestamp or datetime.now(timezone.utc), + timestamp=timestamp or datetime.now(UTC), execution_duration_ms=execution_duration_ms, ) @@ -62,6 +62,20 @@ async def authorize( return Principal(namespace_key="default") +class _NamespaceAuthorizer: + def __init__(self, namespace_key: str) -> None: + self.namespace_key = namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, operation, context + return Principal(namespace_key=self.namespace_key) + + class TestObservabilityAuthFramework: """Tests observability routes declare operation-based authorization.""" @@ -102,6 +116,40 @@ def test_ingest_events_uses_write_operation( assert response.status_code == 202 assert authorizer.calls == [(Operation.OBSERVABILITY_WRITE, None)] + def test_events_are_scoped_to_authorized_namespace( + self, + app: object, + setup_observability: object, + ) -> None: + _ = setup_observability + client = TestClient(app, raise_server_exceptions=True) + agent_name = f"agent-{uuid4().hex[:12]}" + request = BatchEventsRequest(events=[create_test_event(agent_name=agent_name)]) + + set_authorizer(_NamespaceAuthorizer("tenant-a")) + ingest = client.post( + "/api/v1/observability/events", + json=request.model_dump(mode="json"), + ) + assert ingest.status_code == 202 + + query = EventQueryRequest(agent_name=agent_name, limit=10, offset=0) + set_authorizer(_NamespaceAuthorizer("tenant-b")) + tenant_b = client.post( + "/api/v1/observability/events/query", + json=query.model_dump(mode="json"), + ) + assert tenant_b.status_code == 200 + assert tenant_b.json()["total"] == 0 + + set_authorizer(_NamespaceAuthorizer("tenant-a")) + tenant_a = client.post( 
+ "/api/v1/observability/events/query", + json=query.model_dump(mode="json"), + ) + assert tenant_a.status_code == 200 + assert tenant_a.json()["total"] == 1 + class TestEventIngestion: """Tests for POST /events endpoint.""" @@ -224,7 +272,7 @@ def test_event_with_all_fields(self): action="deny", matched=True, confidence=0.99, - timestamp=datetime.now(timezone.utc), + timestamp=datetime.now(UTC), execution_duration_ms=15.5, evaluator_name="regex", selector_path="input", @@ -330,7 +378,7 @@ async def test_stats_with_timeseries(self, client: TestClient, setup_observabili """With include_timeseries=true, returns buckets.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create events spread across time events = [ @@ -387,7 +435,7 @@ async def test_timeseries_bucket_count_1h(self, client: TestClient, setup_observ """Verify reasonable number of buckets for 1h time range (5m buckets).""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create a single event event = create_test_event( @@ -418,7 +466,7 @@ async def test_timeseries_bucket_count_5m(self, client: TestClient, setup_observ """Verify reasonable number of buckets for 5m time range (30s buckets).""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create a single event event = create_test_event( @@ -451,7 +499,7 @@ async def test_timeseries_aggregates_events_per_bucket( """Events in the same bucket are aggregated.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create multiple events in the same 5-minute bucket base_time = now - timedelta(minutes=10) @@ -516,7 +564,7 @@ async def test_timeseries_empty_buckets_included(self, client: TestClient, setup """Empty buckets are included 
with zero counts.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create events only at the start of the time range event = create_test_event( @@ -602,7 +650,7 @@ async def test_control_stats_with_timeseries(self, client: TestClient, setup_obs """Test control stats with timeseries.""" store = setup_observability agent_name = f"agent-{uuid4().hex[:12]}" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Create events for control 1 at different times events = [ @@ -855,7 +903,7 @@ def test_ingest_events_partial_status(self, client: TestClient, setup_observabil # Given: a stub ingestor that drops some events class StubIngestor: - async def ingest(self, events): + async def ingest(self, events, *, namespace_key="default"): return IngestResult(received=len(events), processed=1, dropped=len(events) - 1) original = app.state.event_ingestor @@ -885,7 +933,7 @@ def test_ingest_events_failed_status(self, client: TestClient, setup_observabili # Given: a stub ingestor that drops all events class StubIngestor: - async def ingest(self, events): + async def ingest(self, events, *, namespace_key="default"): return IngestResult(received=len(events), processed=0, dropped=len(events)) original = app.state.event_ingestor diff --git a/server/tests/test_observability_store_postgres.py b/server/tests/test_observability_store_postgres.py index 72cd4ab6..2786130b 100644 --- a/server/tests/test_observability_store_postgres.py +++ b/server/tests/test_observability_store_postgres.py @@ -5,11 +5,12 @@ from uuid import uuid4 import pytest +from agent_control_models.observability import ControlExecutionEvent, EventQueryRequest from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from agent_control_models.observability import ControlExecutionEvent, EventQueryRequest from agent_control_server.observability.store.postgres import PostgresEventStore + from 
.conftest import async_engine, engine @@ -122,6 +123,50 @@ async def test_postgres_event_store_query_events_and_stats() -> None: assert filtered_stats.stats[0].control_id == 1 +@pytest.mark.asyncio +async def test_postgres_event_store_scopes_queries_by_namespace() -> None: + session_maker = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + store = PostgresEventStore(session_maker) + + agent_name = f"agent-{uuid4().hex[:12]}" + now = datetime.now(UTC) + event_a = _event( + agent_name=agent_name, + control_id=1, + action="observe", + matched=True, + timestamp=now, + trace_id="a" * 32, + ) + event_b = _event( + agent_name=agent_name, + control_id=2, + action="deny", + matched=True, + timestamp=now, + trace_id="b" * 32, + ) + + await store.store([event_a], namespace_key="tenant-a") + await store.store([event_b], namespace_key="tenant-b") + + query = EventQueryRequest(agent_name=agent_name, limit=10, offset=0) + events_a = await store.query_events(query, namespace_key="tenant-a") + stats_a = await store.query_stats( + agent_name, + timedelta(hours=1), + namespace_key="tenant-a", + ) + + assert [event.control_id for event in events_a.events] == [1] + assert stats_a.total_executions == 1 + assert stats_a.stats[0].control_id == 1 + + @pytest.mark.asyncio async def test_postgres_event_store_store_empty_returns_zero() -> None: # Given: a Postgres-backed store diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 0863c9a0..a59e9e85 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -8,9 +8,12 @@ from __future__ import annotations +import logging from datetime import UTC, datetime, timedelta import pytest +from fastapi.testclient import TestClient + from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( 
RuntimeAuthConfig, @@ -23,7 +26,6 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) -from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -107,6 +109,38 @@ def test_exchange_endpoint_mints_token_when_configured(client: TestClient, runti assert body["expires_at"] +def test_exchange_audit_log_redacts_actor_id( + client: TestClient, + runtime_config_enabled, + caplog: pytest.LogCaptureFixture, +): + stub = _StubExchangeAuthorizer( + actor_id="user@example.test", + scopes=("runtime.use",), + target_type="log_stream", + target_id="ls-42", + ) + clear_authorizers() + set_authorizer(stub) + + with caplog.at_level(logging.INFO): + response = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-42"}, + ) + + assert response.status_code == 200, response.text + records = [ + record + for record in caplog.records + if record.getMessage() == "Runtime token exchanged" + ] + assert records + record = records[-1] + assert "actor_id" not in record.__dict__ + assert record.__dict__["actor_id_hash"] + + def test_exchange_endpoint_rejects_target_mismatch(client: TestClient, runtime_config_enabled): """Provider says the credential is scoped to one target; body asks for another.""" stub = _StubExchangeAuthorizer( From 6caf1b98c333d5f2f416d8bf846ec4b78ab8052e Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 14 May 2026 23:59:48 +0530 Subject: [PATCH 14/20] fix(server): keep bundled migration config reusable --- server/src/agent_control_server/migrate.py | 28 ++++++++++++++++------ server/tests/test_migrate.py | 2 +- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/server/src/agent_control_server/migrate.py b/server/src/agent_control_server/migrate.py index 1528305f..f861c8db 100644 --- a/server/src/agent_control_server/migrate.py +++ b/server/src/agent_control_server/migrate.py @@ -15,15 +15,15 @@ from collections.abc import 
Iterator from contextlib import contextmanager from pathlib import Path +from typing import cast +from alembic import command from alembic.config import Config import agent_control_server -from alembic import command -@contextmanager -def _bundled_config() -> Iterator[Config]: +def _bundled_config() -> Config: pkg_dir = Path(agent_control_server.__file__).parent ini_path = pkg_dir / "_alembic.ini" alembic_dir = pkg_dir / "_alembic" @@ -33,13 +33,27 @@ def _bundled_config() -> Iterator[Config]: f"{ini_path} and {alembic_dir}. The installed wheel is missing " "migration assets." ) + cfg = Config(str(ini_path)) + cfg.set_main_option("script_location", str(alembic_dir).replace("%", "%%")) + return cfg + + +@contextmanager +def _runtime_bundled_config() -> Iterator[Config]: + cfg = _bundled_config() + if not isinstance(cfg, Config): + yield cast(Config, cfg) + return + + bundled_script_location = cfg.get_main_option("script_location") + if bundled_script_location is None: + raise RuntimeError("Bundled Alembic script_location is not configured.") + with tempfile.TemporaryDirectory(prefix="agent-control-alembic-") as tmp: script_location = Path(tmp) / "_alembic" - shutil.copytree(alembic_dir, script_location) + shutil.copytree(bundled_script_location, script_location) for injected_init in (script_location / "versions").rglob("__init__.py"): injected_init.unlink() - - cfg = Config(str(ini_path)) cfg.set_main_option("script_location", str(script_location).replace("%", "%%")) yield cfg @@ -89,7 +103,7 @@ def main(argv: list[str] | None = None) -> int: _configure_logging() try: - with _bundled_config() as cfg: + with _runtime_bundled_config() as cfg: if parsed.command == "upgrade": command.upgrade(cfg, parsed.revision, sql=parsed.sql) elif parsed.command == "downgrade": diff --git a/server/tests/test_migrate.py b/server/tests/test_migrate.py index 9024d909..c6430e0e 100644 --- a/server/tests/test_migrate.py +++ b/server/tests/test_migrate.py @@ -24,7 +24,7 @@ def 
test_bundled_config_omits_injected_version_init( monkeypatch.setattr(agent_control_server, "__file__", str(package_dir / "__init__.py")) - with migrate._bundled_config() as cfg: + with migrate._runtime_bundled_config() as cfg: script_location = Path(cfg.get_main_option("script_location")) assert script_location.exists() assert (script_location / "versions" / "abc123_example.py").exists() From 1d5cfef57027192f5196d4a31be9e4b025594b11 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 13:48:16 +0530 Subject: [PATCH 15/20] fix(server): require explicit observability namespace --- .../observability/ingest/base.py | 4 +- .../observability/ingest/direct.py | 9 +++-- .../observability/sinks.py | 8 ++-- .../observability/store/base.py | 9 ++--- .../observability/store/postgres.py | 8 ++-- server/tests/test_main_lifespan.py | 8 ++-- .../tests/test_observability_direct_ingest.py | 23 ++++++----- server/tests/test_observability_endpoints.py | 33 ++++++++-------- server/tests/test_observability_sinks.py | 4 +- .../test_observability_store_postgres.py | 39 +++++++++++++------ 10 files changed, 82 insertions(+), 63 deletions(-) diff --git a/server/src/agent_control_server/observability/ingest/base.py b/server/src/agent_control_server/observability/ingest/base.py index cc06c3b5..8fd5116a 100644 --- a/server/src/agent_control_server/observability/ingest/base.py +++ b/server/src/agent_control_server/observability/ingest/base.py @@ -10,8 +10,6 @@ from agent_control_models.observability import ControlExecutionEvent from pydantic import BaseModel, Field -from ...models import DEFAULT_NAMESPACE_KEY - class IngestResult(BaseModel): """Result of an event ingestion operation. @@ -46,7 +44,7 @@ async def ingest( self, events: list[ControlExecutionEvent], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> IngestResult: """Ingest events. Returns counts of received/processed/dropped. 
diff --git a/server/src/agent_control_server/observability/ingest/direct.py b/server/src/agent_control_server/observability/ingest/direct.py index 3b63afa8..faa27197 100644 --- a/server/src/agent_control_server/observability/ingest/direct.py +++ b/server/src/agent_control_server/observability/ingest/direct.py @@ -15,7 +15,6 @@ from agent_control_models.observability import ControlExecutionEvent from agent_control_telemetry.sinks import AsyncControlEventSink -from ...models import DEFAULT_NAMESPACE_KEY from ..sinks import EventStoreControlEventSink from ..store.base import EventStore from .base import EventIngestor, IngestResult @@ -39,7 +38,7 @@ class DirectEventIngestor(EventIngestor): def __init__( self, - store: EventStore | AsyncControlEventSink, + store: EventStore | AsyncControlEventSink | EventStoreControlEventSink, log_to_stdout: bool = False, ): """Initialize the ingestor. @@ -49,7 +48,9 @@ def __init__( log_to_stdout: Whether to log events as structured JSON (default: False) """ if isinstance(store, EventStore): - self.sink: AsyncControlEventSink = EventStoreControlEventSink(store) + self.sink: AsyncControlEventSink | EventStoreControlEventSink = ( + EventStoreControlEventSink(store) + ) else: self.sink = store self.log_to_stdout = log_to_stdout @@ -58,7 +59,7 @@ async def ingest( self, events: list[ControlExecutionEvent], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> IngestResult: """Ingest events by writing them directly to the configured sink. 
diff --git a/server/src/agent_control_server/observability/sinks.py b/server/src/agent_control_server/observability/sinks.py index 11273ce8..d1c91865 100644 --- a/server/src/agent_control_server/observability/sinks.py +++ b/server/src/agent_control_server/observability/sinks.py @@ -17,7 +17,6 @@ ) from agent_control_telemetry.sink_selection import SinkSelectionError -from ..models import DEFAULT_NAMESPACE_KEY from .store.base import EventStore _named_event_sink_factories: ControlEventSinkFactoryRegistry[ResolvedControlEventBackend] = ( @@ -35,7 +34,7 @@ async def write_events( self, events: Sequence[ControlExecutionEvent], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> SinkResult: """Write events to the underlying store and report accepted/dropped counts.""" stored = await self.store.store(list(events), namespace_key=namespace_key) @@ -43,11 +42,14 @@ async def write_events( return SinkResult(accepted=stored, dropped=dropped) +ServerControlEventSink = AsyncControlEventSink | EventStoreControlEventSink + + @dataclass(frozen=True) class ResolvedControlEventBackend: """Server observability backend with aligned write and query dependencies.""" - sink: AsyncControlEventSink + sink: ServerControlEventSink event_store: EventStore diff --git a/server/src/agent_control_server/observability/store/base.py b/server/src/agent_control_server/observability/store/base.py index 4249fcdc..f7231f2d 100644 --- a/server/src/agent_control_server/observability/store/base.py +++ b/server/src/agent_control_server/observability/store/base.py @@ -25,8 +25,6 @@ ) from pydantic import BaseModel, Field -from ...models import DEFAULT_NAMESPACE_KEY - # Type alias for time range literals TimeRange = Literal["1m", "5m", "15m", "1h", "24h", "7d", "30d", "180d", "365d"] @@ -125,7 +123,7 @@ async def store( self, events: list[ControlExecutionEvent], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> int: """Store raw events. 
@@ -143,10 +141,11 @@ async def query_stats( self, agent_name: str, time_range: timedelta, + *, control_id: int | None = None, include_timeseries: bool = False, bucket_size: timedelta | None = None, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> StatsResult: """Query stats (aggregated at query time from raw events). @@ -168,7 +167,7 @@ async def query_events( self, query: EventQuery, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> EventQueryResult: """Query raw events with filters and pagination. diff --git a/server/src/agent_control_server/observability/store/postgres.py b/server/src/agent_control_server/observability/store/postgres.py index 39972f04..435f2ace 100644 --- a/server/src/agent_control_server/observability/store/postgres.py +++ b/server/src/agent_control_server/observability/store/postgres.py @@ -24,7 +24,6 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from ...models import DEFAULT_NAMESPACE_KEY from .base import EventStore, StatsResult logger = logging.getLogger(__name__) @@ -111,7 +110,7 @@ async def store( self, events: list[ControlExecutionEvent], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> int: """Store raw events in PostgreSQL. @@ -170,10 +169,11 @@ async def query_stats( self, agent_name: str, time_range: timedelta, + *, control_id: int | None = None, include_timeseries: bool = False, bucket_size: timedelta | None = None, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> StatsResult: """Query stats aggregated at query time from raw events. @@ -398,7 +398,7 @@ async def query_events( self, query: EventQueryRequest, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> EventQueryResponse: """Query raw events with filters and pagination. 
diff --git a/server/tests/test_main_lifespan.py b/server/tests/test_main_lifespan.py index e6e6f595..69cabd5e 100644 --- a/server/tests/test_main_lifespan.py +++ b/server/tests/test_main_lifespan.py @@ -33,13 +33,13 @@ class DummyStore: def __init__(self) -> None: self.closed = False - async def store(self, events): # type: ignore[no-untyped-def] + async def store(self, events, *, namespace_key: str): # type: ignore[no-untyped-def] return len(events) async def query_stats(self, *args, **kwargs): # type: ignore[no-untyped-def] raise NotImplementedError - async def query_events(self, query): # type: ignore[no-untyped-def] + async def query_events(self, query, *, namespace_key: str): # type: ignore[no-untyped-def] raise NotImplementedError async def close(self) -> None: @@ -103,13 +103,13 @@ def __init__(self) -> None: async def write_events(self, events): # type: ignore[no-untyped-def] raise NotImplementedError - async def store(self, events): # type: ignore[no-untyped-def] + async def store(self, events, *, namespace_key: str): # type: ignore[no-untyped-def] return len(events) async def query_stats(self, *args, **kwargs): # type: ignore[no-untyped-def] raise NotImplementedError - async def query_events(self, query): # type: ignore[no-untyped-def] + async def query_events(self, query, *, namespace_key: str): # type: ignore[no-untyped-def] raise NotImplementedError async def flush(self) -> None: diff --git a/server/tests/test_observability_direct_ingest.py b/server/tests/test_observability_direct_ingest.py index 13e7270c..e419b9e9 100644 --- a/server/tests/test_observability_direct_ingest.py +++ b/server/tests/test_observability_direct_ingest.py @@ -6,6 +6,7 @@ from agent_control_models.observability import ControlExecutionEvent from agent_control_telemetry.sinks import SinkResult +from agent_control_server.models import DEFAULT_NAMESPACE_KEY from agent_control_server.observability.ingest.direct import DirectEventIngestor from 
agent_control_server.observability.store.base import EventStore @@ -15,7 +16,7 @@ async def store( self, events: list[ControlExecutionEvent], *, - namespace_key: str = "default", + namespace_key: str, ) -> int: raise RuntimeError("boom") @@ -23,14 +24,15 @@ async def query_stats( self, agent_name, time_range, + *, control_id=None, include_timeseries=False, bucket_size=None, - namespace_key="default", + namespace_key, ): # pragma: no cover - not used raise NotImplementedError - async def query_events(self, query, *, namespace_key="default"): # pragma: no cover - not used + async def query_events(self, query, *, namespace_key): # pragma: no cover - not used raise NotImplementedError @@ -42,7 +44,7 @@ async def store( self, events: list[ControlExecutionEvent], *, - namespace_key: str = "default", + namespace_key: str, ) -> int: self.calls.append(events) return len(events) @@ -51,14 +53,15 @@ async def query_stats( self, agent_name, time_range, + *, control_id=None, include_timeseries=False, bucket_size=None, - namespace_key="default", + namespace_key, ): # pragma: no cover - not used raise NotImplementedError - async def query_events(self, query, *, namespace_key="default"): # pragma: no cover - not used + async def query_events(self, query, *, namespace_key): # pragma: no cover - not used raise NotImplementedError @@ -91,7 +94,7 @@ async def test_direct_ingestor_drops_on_store_error() -> None: ] # When: ingesting events - result = await ingestor.ingest(events) + result = await ingestor.ingest(events, namespace_key=DEFAULT_NAMESPACE_KEY) # Then: all events are dropped assert result.received == 1 @@ -119,7 +122,7 @@ async def test_direct_ingestor_logs_when_enabled(caplog: pytest.LogCaptureFixtur # When: ingesting events with caplog.at_level(logging.INFO): - result = await ingestor.ingest([event]) + result = await ingestor.ingest([event], namespace_key=DEFAULT_NAMESPACE_KEY) # Then: event is stored and a log line is emitted assert result.processed == 1 @@ -133,7 +136,7 
@@ async def test_direct_ingestor_empty_events_returns_zeroes() -> None: ingestor = DirectEventIngestor(store=CountingStore()) # When: ingesting an empty list - result = await ingestor.ingest([]) + result = await ingestor.ingest([], namespace_key=DEFAULT_NAMESPACE_KEY) # Then: counts are zeroed assert result.received == 0 @@ -172,7 +175,7 @@ async def test_direct_ingestor_accepts_control_event_sink() -> None: ) ] - result = await ingestor.ingest(events) + result = await ingestor.ingest(events, namespace_key=DEFAULT_NAMESPACE_KEY) assert result.received == 1 assert result.processed == 1 diff --git a/server/tests/test_observability_endpoints.py b/server/tests/test_observability_endpoints.py index 114713c7..c6b722e8 100644 --- a/server/tests/test_observability_endpoints.py +++ b/server/tests/test_observability_endpoints.py @@ -17,6 +17,7 @@ from agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY from agent_control_server.observability.ingest.base import IngestResult @@ -298,7 +299,7 @@ async def test_store_events(self, setup_observability): store = setup_observability events = [create_test_event(i) for i in range(5)] - stored = await store.store(events) + stored = await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) assert stored == 5 @pytest.mark.asyncio @@ -307,9 +308,9 @@ async def test_store_deduplicates_events(self, setup_observability): store = setup_observability event = create_test_event() - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) # Storing same event again should not raise, but also not duplicate - stored = await store.store([event]) + stored = await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) # ON CONFLICT DO NOTHING returns the batch size, not actual inserts assert stored == 1 @@ -322,7 +323,7 @@ async def test_ingest_via_direct_ingestor(self, 
setup_observability): ingestor = DirectEventIngestor(store) events = [create_test_event(i) for i in range(3)] - result = await ingestor.ingest(events) + result = await ingestor.ingest(events, namespace_key=DEFAULT_NAMESPACE_KEY) assert result.received == 3 assert result.processed == 3 @@ -341,7 +342,7 @@ async def test_stats_normalize_mixed_case_agent_name_query( normalized_name = "agent-statsnorm01" event = create_test_event(agent_name=normalized_name, matched=True) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -361,7 +362,7 @@ async def test_stats_without_timeseries(self, client: TestClient, setup_observab # Create and store an event event = create_test_event(agent_name=agent_name, matched=True) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -403,7 +404,7 @@ async def test_stats_with_timeseries(self, client: TestClient, setup_observabili ), ] - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -444,7 +445,7 @@ async def test_timeseries_bucket_count_1h(self, client: TestClient, setup_observ timestamp=now - timedelta(minutes=30), ) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -475,7 +476,7 @@ async def test_timeseries_bucket_count_5m(self, client: TestClient, setup_observ timestamp=now - timedelta(minutes=2), ) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -526,7 +527,7 @@ async def test_timeseries_aggregates_events_per_bucket( ), ] - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( 
"/api/v1/observability/stats", @@ -573,7 +574,7 @@ async def test_timeseries_empty_buckets_included(self, client: TestClient, setup timestamp=now - timedelta(minutes=55), ) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats", @@ -621,7 +622,7 @@ async def test_control_stats_basic(self, client: TestClient, setup_observability create_test_event(control_id=1, agent_name=agent_name, matched=True, action="deny"), create_test_event(control_id=2, agent_name=agent_name, matched=True, action="observe"), ] - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) # Get stats for control 1 only response = client.get( @@ -677,7 +678,7 @@ async def test_control_stats_with_timeseries(self, client: TestClient, setup_obs timestamp=now - timedelta(minutes=20), ), ] - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) response = client.get( "/api/v1/observability/stats/controls/1", @@ -710,7 +711,7 @@ async def test_control_stats_no_data(self, client: TestClient, setup_observabili # Create event for control 1 only event = create_test_event(control_id=1, agent_name=agent_name, matched=True) - await store.store([event]) + await store.store([event], namespace_key=DEFAULT_NAMESPACE_KEY) # Query for control 2 (no events) response = client.get( @@ -903,7 +904,7 @@ def test_ingest_events_partial_status(self, client: TestClient, setup_observabil # Given: a stub ingestor that drops some events class StubIngestor: - async def ingest(self, events, *, namespace_key="default"): + async def ingest(self, events, *, namespace_key: str): return IngestResult(received=len(events), processed=1, dropped=len(events) - 1) original = app.state.event_ingestor @@ -933,7 +934,7 @@ def test_ingest_events_failed_status(self, client: TestClient, setup_observabili # Given: a stub ingestor that drops all events class StubIngestor: - async def 
ingest(self, events, *, namespace_key="default"): + async def ingest(self, events, *, namespace_key: str): return IngestResult(received=len(events), processed=0, dropped=len(events)) original = app.state.event_ingestor diff --git a/server/tests/test_observability_sinks.py b/server/tests/test_observability_sinks.py index 38122f1b..ff2a66c8 100644 --- a/server/tests/test_observability_sinks.py +++ b/server/tests/test_observability_sinks.py @@ -21,13 +21,13 @@ async def write_events(self, events): # type: ignore[no-untyped-def] class DummyStore: - async def store(self, events): # type: ignore[no-untyped-def] + async def store(self, events, *, namespace_key: str): # type: ignore[no-untyped-def] return len(events) async def query_stats(self, *args, **kwargs): # type: ignore[no-untyped-def] raise NotImplementedError - async def query_events(self, query): # type: ignore[no-untyped-def] + async def query_events(self, query, *, namespace_key: str): # type: ignore[no-untyped-def] raise NotImplementedError async def close(self) -> None: diff --git a/server/tests/test_observability_store_postgres.py b/server/tests/test_observability_store_postgres.py index 2786130b..9c5942f0 100644 --- a/server/tests/test_observability_store_postgres.py +++ b/server/tests/test_observability_store_postgres.py @@ -9,6 +9,7 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from agent_control_server.models import DEFAULT_NAMESPACE_KEY from agent_control_server.observability.store.postgres import PostgresEventStore from .conftest import async_engine, engine @@ -91,24 +92,26 @@ async def test_postgres_event_store_query_events_and_stats() -> None: ] # When: storing events - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) # When: querying events filtered by control_id query = EventQueryRequest(agent_name=agent_name, control_ids=[1], limit=10, offset=0) - resp = await store.query_events(query) + resp = await 
store.query_events(query, namespace_key=DEFAULT_NAMESPACE_KEY) # Then: only matching events are returned assert resp.total == 2 assert all(e.control_id == 1 for e in resp.events) # When: querying events filtered by trace_id query = EventQueryRequest(trace_id="a" * 32, limit=10, offset=0) - resp = await store.query_events(query) + resp = await store.query_events(query, namespace_key=DEFAULT_NAMESPACE_KEY) # Then: only matching events are returned assert resp.total == 2 assert all(e.trace_id == "a" * 32 for e in resp.events) # When: querying stats - stats = await store.query_stats(agent_name, timedelta(hours=1)) + stats = await store.query_stats( + agent_name, timedelta(hours=1), namespace_key=DEFAULT_NAMESPACE_KEY + ) # Then: totals and action counts are aggregated correctly assert stats.total_executions == 3 assert stats.total_matches == 2 @@ -117,7 +120,12 @@ async def test_postgres_event_store_query_events_and_stats() -> None: assert stats.action_counts == {"observe": 2} # When: querying stats with a control filter - filtered_stats = await store.query_stats(agent_name, timedelta(hours=1), control_id=1) + filtered_stats = await store.query_stats( + agent_name, + timedelta(hours=1), + control_id=1, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: only the requested control is returned assert len(filtered_stats.stats) == 1 assert filtered_stats.stats[0].control_id == 1 @@ -178,7 +186,7 @@ async def test_postgres_event_store_store_empty_returns_zero() -> None: store = PostgresEventStore(session_maker) # When: storing an empty event list - stored = await store.store([]) + stored = await store.store([], namespace_key=DEFAULT_NAMESPACE_KEY) # Then: zero events are reported as stored assert stored == 0 @@ -229,7 +237,7 @@ async def test_postgres_event_store_query_events_all_filters() -> None: ), ] - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) # When: querying with all supported filters query = EventQueryRequest( @@ -247,7 
+255,7 @@ async def test_postgres_event_store_query_events_all_filters() -> None: limit=10, offset=0, ) - resp = await store.query_events(query) + resp = await store.query_events(query, namespace_key=DEFAULT_NAMESPACE_KEY) # Then: only the matching event is returned assert resp.total == 1 @@ -297,9 +305,12 @@ async def test_postgres_event_store_normalizes_legacy_advisory_rows() -> None: # When: querying with the canonical observe filter resp = await store.query_events( - EventQueryRequest(agent_name=agent_name, actions=["observe"], limit=10, offset=0) + EventQueryRequest(agent_name=agent_name, actions=["observe"], limit=10, offset=0), + namespace_key=DEFAULT_NAMESPACE_KEY, + ) + stats = await store.query_stats( + agent_name, timedelta(hours=1), namespace_key=DEFAULT_NAMESPACE_KEY ) - stats = await store.query_stats(agent_name, timedelta(hours=1)) # Then: the legacy row is returned and normalized to observe assert resp.total == 1 @@ -348,7 +359,7 @@ async def test_postgres_event_store_timeseries_includes_steer_and_observe_counts ] # When: storing events - await store.store(events) + await store.store(events, namespace_key=DEFAULT_NAMESPACE_KEY) # When: querying stats with timeseries enabled stats = await store.query_stats( @@ -356,6 +367,7 @@ async def test_postgres_event_store_timeseries_includes_steer_and_observe_counts time_range=timedelta(hours=1), include_timeseries=True, bucket_size=timedelta(minutes=1), + namespace_key=DEFAULT_NAMESPACE_KEY, ) # Then: action counts include steer and observe @@ -428,7 +440,10 @@ def __call__(self): # type: ignore[no-untyped-def] store = PostgresEventStore(DummySessionMaker(rows)) # When: querying events - resp = await store.query_events(EventQueryRequest(limit=10, offset=0)) + resp = await store.query_events( + EventQueryRequest(limit=10, offset=0), + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: the JSON string is parsed into ControlExecutionEvent assert resp.total == 1 From 50fd779be75f37a6ecd344965ea708c455c5bf96 Mon Sep 
17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 14:22:19 +0530 Subject: [PATCH 16/20] fix(server): serialize alembic migrations --- server/src/agent_control_server/migrate.py | 128 ++++++++++++++++++--- server/tests/test_migrate.py | 91 +++++++++++++++ 2 files changed, 206 insertions(+), 13 deletions(-) diff --git a/server/src/agent_control_server/migrate.py b/server/src/agent_control_server/migrate.py index f861c8db..657711c1 100644 --- a/server/src/agent_control_server/migrate.py +++ b/server/src/agent_control_server/migrate.py @@ -9,18 +9,36 @@ import argparse import logging +import os import shutil import sys import tempfile +import time from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path from typing import cast -from alembic import command from alembic.config import Config +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Connection +from sqlalchemy.engine.url import make_url +from sqlalchemy.pool import NullPool import agent_control_server +from agent_control_server.config import db_config +from alembic import command + +LOGGER = logging.getLogger(__name__) +_MIGRATION_LOCK_CLASS_ID = 0x4143544C # "ACTL" +_MIGRATION_LOCK_OBJECT_ID = 0x4D494752 # "MIGR" +_MIGRATION_LOCK_POLL_SECONDS = 2.0 +_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS = 600.0 +_MIGRATION_LOCK_TIMEOUT_ENV = "AGENT_CONTROL_MIGRATION_LOCK_TIMEOUT_SECONDS" +_MIGRATION_LOCK_PARAMS = { + "class_id": _MIGRATION_LOCK_CLASS_ID, + "object_id": _MIGRATION_LOCK_OBJECT_ID, +} def _bundled_config() -> Config: @@ -58,6 +76,88 @@ def _runtime_bundled_config() -> Iterator[Config]: yield cfg +def _migration_url(cfg: Config) -> str: + configured_url = cfg.get_main_option("sqlalchemy.url") + if configured_url: + return configured_url + return db_config.get_url() + + +def _migration_lock_timeout_seconds() -> float: + raw_timeout = os.getenv(_MIGRATION_LOCK_TIMEOUT_ENV) + if raw_timeout is None: + return 
_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS + + try: + timeout = float(raw_timeout) + except ValueError as exc: + raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be a number.") from exc + + if timeout <= 0: + raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be greater than zero.") + return timeout + + +def _acquire_migration_lock(connection: Connection, timeout_seconds: float) -> None: + deadline = time.monotonic() + timeout_seconds + logged_wait = False + + while True: + acquired = bool( + connection.execute( + text("SELECT pg_try_advisory_lock(:class_id, :object_id)"), + _MIGRATION_LOCK_PARAMS, + ).scalar_one() + ) + if acquired: + LOGGER.info("Acquired Agent Control migration advisory lock.") + return + + remaining = deadline - time.monotonic() + if remaining <= 0: + raise TimeoutError( + f"Timed out after {timeout_seconds:g}s waiting for Agent Control " + "migration advisory lock." + ) + + if not logged_wait: + LOGGER.info("Waiting for another Agent Control migration to finish.") + logged_wait = True + time.sleep(min(_MIGRATION_LOCK_POLL_SECONDS, remaining)) + + +@contextmanager +def _serialized_migration(cfg: Config, *, enabled: bool) -> Iterator[None]: + if not enabled: + yield + return + + url = _migration_url(cfg) + if make_url(url).get_backend_name() != "postgresql": + yield + return + + engine = create_engine(url, future=True, poolclass=NullPool) + try: + with engine.connect() as connection: + _acquire_migration_lock(connection, _migration_lock_timeout_seconds()) + try: + yield + finally: + released = bool( + connection.execute( + text("SELECT pg_advisory_unlock(:class_id, :object_id)"), + _MIGRATION_LOCK_PARAMS, + ).scalar_one() + ) + if released: + LOGGER.info("Released Agent Control migration advisory lock.") + else: + LOGGER.warning("Agent Control migration advisory lock was not held at release.") + finally: + engine.dispose() + + def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="agent-control-migrate", 
@@ -104,18 +204,20 @@ def main(argv: list[str] | None = None) -> int: try: with _runtime_bundled_config() as cfg: - if parsed.command == "upgrade": - command.upgrade(cfg, parsed.revision, sql=parsed.sql) - elif parsed.command == "downgrade": - command.downgrade(cfg, parsed.revision, sql=parsed.sql) - elif parsed.command == "current": - command.current(cfg) - elif parsed.command == "history": - command.history(cfg) - elif parsed.command == "heads": - command.heads(cfg) - else: # pragma: no cover - argparse guarantees this cannot happen. - parser.error("missing command") + should_lock = parsed.command in {"upgrade", "downgrade"} and not parsed.sql + with _serialized_migration(cfg, enabled=should_lock): + if parsed.command == "upgrade": + command.upgrade(cfg, parsed.revision, sql=parsed.sql) + elif parsed.command == "downgrade": + command.downgrade(cfg, parsed.revision, sql=parsed.sql) + elif parsed.command == "current": + command.current(cfg) + elif parsed.command == "history": + command.history(cfg) + elif parsed.command == "heads": + command.heads(cfg) + else: # pragma: no cover - argparse guarantees this cannot happen. 
+ parser.error("missing command") except Exception as exc: print(f"agent-control-migrate: {exc}", file=sys.stderr) return 1 diff --git a/server/tests/test_migrate.py b/server/tests/test_migrate.py index c6430e0e..d2433cc9 100644 --- a/server/tests/test_migrate.py +++ b/server/tests/test_migrate.py @@ -2,10 +2,53 @@ from pathlib import Path +from alembic.config import Config + import agent_control_server from agent_control_server import migrate +class _FakeResult: + def __init__(self, value: bool) -> None: + self.value = value + + def scalar_one(self) -> bool: + return self.value + + +class _FakeConnection: + def __init__(self, lock_results: list[bool]) -> None: + self.lock_results = lock_results + self.statements: list[str] = [] + + def __enter__(self) -> _FakeConnection: + return self + + def __exit__(self, *args: object) -> None: + return None + + def execute(self, statement: object, params: object) -> _FakeResult: + statement_text = str(statement) + self.statements.append(statement_text) + if "pg_try_advisory_lock" in statement_text: + return _FakeResult(self.lock_results.pop(0)) + if "pg_advisory_unlock" in statement_text: + return _FakeResult(True) + raise AssertionError(f"unexpected SQL statement: {statement_text}") + + +class _FakeEngine: + def __init__(self, connection: _FakeConnection) -> None: + self.connection = connection + self.disposed = False + + def connect(self) -> _FakeConnection: + return self.connection + + def dispose(self) -> None: + self.disposed = True + + def test_bundled_config_omits_injected_version_init( tmp_path: Path, monkeypatch, @@ -31,3 +74,51 @@ def test_bundled_config_omits_injected_version_init( assert not (script_location / "versions" / "__init__.py").exists() assert not script_location.exists() + + +def test_serialized_migration_skips_lock_for_non_postgres_url(monkeypatch) -> None: + cfg = Config() + cfg.set_main_option("sqlalchemy.url", "sqlite:///agent-control.db") + + def fail_create_engine(*args: object, **kwargs: object) 
-> object: + raise AssertionError("non-postgres migrations should not create a lock connection") + + monkeypatch.setattr(migrate, "create_engine", fail_create_engine) + + with migrate._serialized_migration(cfg, enabled=True): + pass + + +def test_serialized_migration_acquires_and_releases_postgres_lock(monkeypatch) -> None: + cfg = Config() + cfg.set_main_option("sqlalchemy.url", "postgresql+psycopg://user:pass@postgres/db") + connection = _FakeConnection([False, True]) + engine = _FakeEngine(connection) + sleeps: list[float] = [] + + monkeypatch.setattr(migrate, "create_engine", lambda *args, **kwargs: engine) + monkeypatch.setattr(migrate.time, "sleep", lambda seconds: sleeps.append(seconds)) + + with migrate._serialized_migration(cfg, enabled=True): + pass + + assert connection.statements == [ + "SELECT pg_try_advisory_lock(:class_id, :object_id)", + "SELECT pg_try_advisory_lock(:class_id, :object_id)", + "SELECT pg_advisory_unlock(:class_id, :object_id)", + ] + assert sleeps == [2.0] + assert engine.disposed + + +def test_serialized_migration_respects_disabled_lock(monkeypatch) -> None: + cfg = Config() + cfg.set_main_option("sqlalchemy.url", "postgresql+psycopg://user:pass@postgres/db") + + def fail_create_engine(*args: object, **kwargs: object) -> object: + raise AssertionError("disabled migration lock should not create a lock connection") + + monkeypatch.setattr(migrate, "create_engine", fail_create_engine) + + with migrate._serialized_migration(cfg, enabled=False): + pass From 8ce320a03a90031aacd42c94a0e5a7c41bed1967 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 14:41:43 +0530 Subject: [PATCH 17/20] fix(server): update migration dispatch test config --- server/src/agent_control_server/migrate.py | 2 +- server/tests/test_migrate.py | 3 +- server/unit_tests/test_migrate.py | 38 +++++++++++++--------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/server/src/agent_control_server/migrate.py 
b/server/src/agent_control_server/migrate.py index 657711c1..3f260d4a 100644 --- a/server/src/agent_control_server/migrate.py +++ b/server/src/agent_control_server/migrate.py @@ -19,6 +19,7 @@ from pathlib import Path from typing import cast +from alembic import command from alembic.config import Config from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection @@ -27,7 +28,6 @@ import agent_control_server from agent_control_server.config import db_config -from alembic import command LOGGER = logging.getLogger(__name__) _MIGRATION_LOCK_CLASS_ID = 0x4143544C # "ACTL" diff --git a/server/tests/test_migrate.py b/server/tests/test_migrate.py index d2433cc9..eaed9798 100644 --- a/server/tests/test_migrate.py +++ b/server/tests/test_migrate.py @@ -2,10 +2,9 @@ from pathlib import Path -from alembic.config import Config - import agent_control_server from agent_control_server import migrate +from alembic.config import Config class _FakeResult: diff --git a/server/unit_tests/test_migrate.py b/server/unit_tests/test_migrate.py index 874b21a5..ea5ded3c 100644 --- a/server/unit_tests/test_migrate.py +++ b/server/unit_tests/test_migrate.py @@ -9,24 +9,32 @@ from __future__ import annotations import tomllib +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path, PurePosixPath from unittest.mock import MagicMock import pytest -from alembic.script import ScriptDirectory - from agent_control_server import migrate +from alembic.config import Config +from alembic.script import ScriptDirectory @pytest.fixture -def stub_config(monkeypatch: pytest.MonkeyPatch) -> object: - """Replace bundled-config building with a sentinel object. +def stub_config(monkeypatch: pytest.MonkeyPatch) -> Config: + """Replace runtime config building with a lightweight Alembic config. Lets dispatch tests verify which Alembic command was called and - what config was passed without needing real migration assets. 
+ what config was passed without touching real migration assets or DB locks. """ - sentinel = object() - monkeypatch.setattr(migrate, "_bundled_config", lambda: sentinel) + sentinel = Config() + sentinel.set_main_option("sqlalchemy.url", "sqlite:///agent-control-test.db") + + @contextmanager + def fake_runtime_bundled_config() -> Iterator[Config]: + yield sentinel + + monkeypatch.setattr(migrate, "_runtime_bundled_config", fake_runtime_bundled_config) return sentinel @@ -37,7 +45,7 @@ def _patch_command(monkeypatch: pytest.MonkeyPatch, name: str) -> MagicMock: def test_main_default_runs_upgrade_head( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: upgrade = _patch_command(monkeypatch, "upgrade") rc = migrate.main([]) @@ -46,7 +54,7 @@ def test_main_default_runs_upgrade_head( def test_main_bare_upgrade_runs_upgrade_head( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: upgrade = _patch_command(monkeypatch, "upgrade") rc = migrate.main(["upgrade"]) @@ -55,7 +63,7 @@ def test_main_bare_upgrade_runs_upgrade_head( def test_main_explicit_upgrade_revision( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: upgrade = _patch_command(monkeypatch, "upgrade") rc = migrate.main(["upgrade", "abc123"]) @@ -64,7 +72,7 @@ def test_main_explicit_upgrade_revision( def test_main_upgrade_supports_sql( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: upgrade = _patch_command(monkeypatch, "upgrade") rc = migrate.main(["upgrade", "head", "--sql"]) @@ -82,7 +90,7 @@ def test_main_bare_downgrade_requires_explicit_revision( def test_main_explicit_downgrade_revision( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: downgrade = 
_patch_command(monkeypatch, "downgrade") rc = migrate.main(["downgrade", "abc123"]) @@ -91,7 +99,7 @@ def test_main_explicit_downgrade_revision( def test_main_downgrade_supports_sql( - stub_config: object, monkeypatch: pytest.MonkeyPatch + stub_config: Config, monkeypatch: pytest.MonkeyPatch ) -> None: downgrade = _patch_command(monkeypatch, "downgrade") rc = migrate.main(["downgrade", "-1", "--sql"]) @@ -101,7 +109,7 @@ def test_main_downgrade_supports_sql( @pytest.mark.parametrize("op", ["current", "history", "heads"]) def test_main_query_commands( - stub_config: object, monkeypatch: pytest.MonkeyPatch, op: str + stub_config: Config, monkeypatch: pytest.MonkeyPatch, op: str ) -> None: cmd = _patch_command(monkeypatch, op) rc = migrate.main([op]) @@ -147,7 +155,7 @@ def test_main_rejects_extra_positional_args(monkeypatch: pytest.MonkeyPatch) -> def test_main_returns_nonzero_for_command_errors( - stub_config: object, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] + stub_config: Config, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] ) -> None: upgrade = _patch_command(monkeypatch, "upgrade") upgrade.side_effect = RuntimeError("database unavailable") From 3c46474b4efe86137019aaeb4da6bee394266017 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 17:02:09 +0530 Subject: [PATCH 18/20] fix(server): harden namespace rollout edges --- ...c2d8e9a1_namespace_observability_events.py | 40 ++++++++++--- .../auth_framework/config.py | 7 ++- .../auth_framework/providers/http_upstream.py | 12 +++- .../agent_control_server/endpoints/agents.py | 7 +++ server/src/agent_control_server/models.py | 9 ++- .../observability/store/postgres.py | 2 +- server/tests/test_auth_framework.py | 56 ++++++++++++++++++- .../test_data_model_v1_alembic_migration.py | 16 ++++++ server/tests/test_init_agent_conflict_mode.py | 29 ++++++++++ .../test_observability_store_postgres.py | 42 ++++++++++++++ 10 files changed, 204 insertions(+), 16 
deletions(-) diff --git a/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py b/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py index fe769116..c4bb1951 100644 --- a/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py +++ b/server/alembic/versions/b6f4c2d8e9a1_namespace_observability_events.py @@ -28,17 +28,43 @@ def upgrade() -> None: nullable=False, ), ) - op.create_index( - "ix_events_namespace_agent_time", + op.drop_constraint( + "control_execution_events_pkey", "control_execution_events", - ["namespace_key", "agent_name", sa.literal_column("timestamp DESC")], - unique=False, + type_="primary", ) + op.create_primary_key( + "control_execution_events_pkey", + "control_execution_events", + ["namespace_key", "control_execution_id"], + ) + with op.get_context().autocommit_block(): + op.execute("DROP INDEX CONCURRENTLY IF EXISTS ix_events_agent_time") + op.execute( + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_events_namespace_agent_time + ON control_execution_events (namespace_key, agent_name, timestamp DESC) + """ + ) def downgrade() -> None: - op.drop_index( - "ix_events_namespace_agent_time", - table_name="control_execution_events", + with op.get_context().autocommit_block(): + op.execute("DROP INDEX CONCURRENTLY IF EXISTS ix_events_namespace_agent_time") + op.execute( + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_events_agent_time + ON control_execution_events (agent_name, timestamp DESC) + """ + ) + op.drop_constraint( + "control_execution_events_pkey", + "control_execution_events", + type_="primary", + ) + op.create_primary_key( + "control_execution_events_pkey", + "control_execution_events", + ["control_execution_id"], ) op.drop_column("control_execution_events", "namespace_key") diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 559ba425..8d8fbcd9 100644 --- 
a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -232,12 +232,12 @@ def _build_default_provider() -> RequestAuthorizer: ) -def _validate_local_api_key_mode() -> None: +def _validate_local_api_key_mode(mode_env: str = _MODE_ENV) -> None: """Fail startup when local API-key mode has no local key validator.""" if not auth_settings.api_key_enabled: raise RuntimeError( - f"{_MODE_ENV}=api_key requires AGENT_CONTROL_API_KEY_ENABLED=true. " - f"Use {_MODE_ENV}=none for deployments without credential enforcement." + f"{mode_env}=api_key requires AGENT_CONTROL_API_KEY_ENABLED=true. " + f"Use {mode_env}=none for deployments without credential enforcement." ) if not auth_settings.get_api_keys() and not auth_settings.get_admin_api_keys(): raise RuntimeError( @@ -295,6 +295,7 @@ def _build_runtime_provider( if mode == "none": return NoAuthProvider() if mode == "api_key": + _validate_local_api_key_mode(_RUNTIME_MODE_ENV) return HeaderAuthProvider() if mode == "jwt": if config is None: diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 27c776bd..4223a5fd 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -336,12 +336,20 @@ def _ensure_target_context_matches_grant( if principal.target_type is None and principal.target_id is None: return if context is None: - return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Authorization grant is target-bound but the request target is unavailable.", + hint="Use an endpoint that includes target_type and target_id in the authorization context.", + ) expected_type = context.get("target_type") expected_id = context.get("target_id") if not isinstance(expected_type, str) or not isinstance(expected_id, 
str): - return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Authorization grant is target-bound but the request target is incomplete.", + hint="Provide both target_type and target_id for target-bound credentials.", + ) if principal.target_type == expected_type and principal.target_id == expected_id: return diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index d29fbdfc..1d8efe4b 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -912,6 +912,13 @@ async def init_agent( data_model.evaluators = new_evaluators + if ( + not request.force_replace + and request.conflict_mode != ConflictMode.OVERWRITE + and (steps_changed or evaluators_changed or metadata_changed) + ): + await _authorize_existing_agent_overwrite(http_request, principal) + if steps_changed or evaluators_changed or metadata_changed or force_write: existing.data = data_model.model_dump(mode="json") diff --git a/server/src/agent_control_server/models.py b/server/src/agent_control_server/models.py index a0dfc0ed..cad73c23 100644 --- a/server/src/agent_control_server/models.py +++ b/server/src/agent_control_server/models.py @@ -14,6 +14,7 @@ ForeignKeyConstraint, Index, Integer, + PrimaryKeyConstraint, String, Table, Text, @@ -354,7 +355,7 @@ class ControlExecutionEventDB(Base): # Primary key control_execution_id: Mapped[str] = mapped_column( - String(36), primary_key=True + String(36) ) # Minimal indexed columns for efficient queries @@ -377,7 +378,11 @@ class ControlExecutionEventDB(Base): # Composite index for agent + time queries (primary access pattern) __table_args__ = ( + PrimaryKeyConstraint( + "namespace_key", + "control_execution_id", + name="control_execution_events_pkey", + ), Index("ix_events_namespace_agent_time", "namespace_key", "agent_name", timestamp.desc()), - Index("ix_events_agent_time", "agent_name", 
timestamp.desc()), Index("ix_events_data_control_id", text("(data ->> 'control_id'::text)")), ) diff --git a/server/src/agent_control_server/observability/store/postgres.py b/server/src/agent_control_server/observability/store/postgres.py index 435f2ace..b0c7f066 100644 --- a/server/src/agent_control_server/observability/store/postgres.py +++ b/server/src/agent_control_server/observability/store/postgres.py @@ -156,7 +156,7 @@ async def store( :namespace_key, :control_execution_id, :timestamp, :agent_name, CAST(:data AS JSONB) ) - ON CONFLICT (control_execution_id) DO NOTHING + ON CONFLICT (namespace_key, control_execution_id) DO NOTHING """), values, ) diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 83276744..a95f0252 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -1077,7 +1077,11 @@ async def test_http_upstream_accepts_iso_datetime_and_array_scopes(): }, ) ) - principal = await provider.authorize(_build_request(), Operation.RUNTIME_TOKEN_EXCHANGE) + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + context={"target_type": "log_stream", "target_id": "ls-1"}, + ) assert principal.namespace_key == "org-1" assert principal.scopes == ("runtime.use", "runtime.read_only") assert principal.target_type == "log_stream" @@ -1107,6 +1111,44 @@ async def test_http_upstream_rejects_target_grant_mismatch(): ) +@pytest.mark.asyncio +async def test_http_upstream_rejects_target_grant_without_context(): + provider = _build_upstream( + lambda req: httpx.Response( + 200, + json={ + "namespace_key": "org-1", + "target_type": "log_stream", + "target_id": "bound", + }, + ) + ) + + with pytest.raises(ForbiddenError, match="request target is unavailable"): + await provider.authorize(_build_request(), Operation.CONTROL_BINDINGS_READ) + + +@pytest.mark.asyncio +async def test_http_upstream_rejects_target_grant_with_incomplete_context(): + provider = 
_build_upstream( + lambda req: httpx.Response( + 200, + json={ + "namespace_key": "org-1", + "target_type": "log_stream", + "target_id": "bound", + }, + ) + ) + + with pytest.raises(ForbiddenError, match="request target is incomplete"): + await provider.authorize( + _build_request(), + Operation.CONTROL_BINDINGS_READ, + context={"target_type": "log_stream"}, + ) + + # --------------------------------------------------------------------------- # configure_auth_from_env / teardown_auth lifecycle # --------------------------------------------------------------------------- @@ -1248,6 +1290,18 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +def test_configure_runtime_api_key_rejects_without_validator(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setattr(auth_settings, "api_key_enabled", False) + + with pytest.raises(RuntimeError, match="AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key"): + auth_config.configure_auth_from_env() + + def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_data_model_v1_alembic_migration.py b/server/tests/test_data_model_v1_alembic_migration.py index 242fa1d7..7ca22e3a 100644 --- a/server/tests/test_data_model_v1_alembic_migration.py +++ b/server/tests/test_data_model_v1_alembic_migration.py @@ -16,6 +16,7 @@ SERVER_DIR = Path(__file__).resolve().parents[1] PRE_MIGRATION_REVISION = "c1e9f9c4a1d2" MIGRATION_REVISION = "a7f3b1e0d9c5" +OBSERVABILITY_NAMESPACE_REVISION = "b6f4c2d8e9a1" _BASE_DB_URL = make_url(db_config.get_url()) pytestmark = pytest.mark.skipif( @@ -223,6 +224,21 @@ def test_downgrade_round_trip(alembic_config: Config, temp_engine: Engine) -> No assert "control_bindings" in 
inspect(temp_engine).get_table_names() +def test_observability_namespace_migration_scopes_event_primary_key( + alembic_config: Config, temp_engine: Engine +) -> None: + command.upgrade(alembic_config, OBSERVABILITY_NAMESPACE_REVISION) + + assert "namespace_key" in _column_names(temp_engine, "control_execution_events") + assert _pk_columns(temp_engine, "control_execution_events") == [ + "namespace_key", + "control_execution_id", + ] + indexes = _index_names(temp_engine, "control_execution_events") + assert "ix_events_namespace_agent_time" in indexes + assert "ix_events_agent_time" not in indexes + + def test_downgrade_rejects_cross_namespace_agents_duplicates( alembic_config: Config, temp_engine: Engine ) -> None: diff --git a/server/tests/test_init_agent_conflict_mode.py b/server/tests/test_init_agent_conflict_mode.py index 2e8b9b80..0397ce94 100644 --- a/server/tests/test_init_agent_conflict_mode.py +++ b/server/tests/test_init_agent_conflict_mode.py @@ -212,6 +212,35 @@ def test_init_agent_force_replace_existing_agent_requires_update_auth( assert force_resp.status_code == 403 +def test_init_agent_strict_existing_agent_mutation_requires_update_auth( + client: TestClient, +) -> None: + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + create_resp = client.post( + "/api/v1/agents/initAgent", + json=_init_payload(agent_name=agent_name), + ) + assert create_resp.status_code == 200 + + set_authorizer(CreateOnlyAuthorizer()) + strict_resp = client.post( + "/api/v1/agents/initAgent", + json=_init_payload( + agent_name=agent_name, + steps=[ + { + "type": "tool", + "name": "new-tool", + "input_schema": {"type": "object"}, + "output_schema": {"type": "object"}, + } + ], + ), + ) + + assert strict_resp.status_code == 403 + + def test_init_agent_overwrite_warns_on_removed_referenced_evaluator(client: TestClient) -> None: # Given: an agent whose assigned policy contains a control referencing an agent evaluator. 
agent_name = f"agent-{uuid.uuid4().hex[:12]}" diff --git a/server/tests/test_observability_store_postgres.py b/server/tests/test_observability_store_postgres.py index 9c5942f0..41e99f87 100644 --- a/server/tests/test_observability_store_postgres.py +++ b/server/tests/test_observability_store_postgres.py @@ -175,6 +175,48 @@ async def test_postgres_event_store_scopes_queries_by_namespace() -> None: assert stats_a.stats[0].control_id == 1 +@pytest.mark.asyncio +async def test_postgres_event_store_idempotency_is_scoped_by_namespace() -> None: + session_maker = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + store = PostgresEventStore(session_maker) + + shared_execution_id = str(uuid4()) + agent_name = f"agent-{uuid4().hex[:12]}" + now = datetime.now(UTC) + event_a = _event( + control_execution_id=shared_execution_id, + agent_name=agent_name, + control_id=1, + action="observe", + matched=True, + timestamp=now, + trace_id="a" * 32, + ) + event_b = _event( + control_execution_id=shared_execution_id, + agent_name=agent_name, + control_id=2, + action="deny", + matched=True, + timestamp=now, + trace_id="b" * 32, + ) + + await store.store([event_a], namespace_key="tenant-a") + await store.store([event_b], namespace_key="tenant-b") + + query = EventQueryRequest(agent_name=agent_name, limit=10, offset=0) + events_a = await store.query_events(query, namespace_key="tenant-a") + events_b = await store.query_events(query, namespace_key="tenant-b") + + assert [event.control_id for event in events_a.events] == [1] + assert [event.control_id for event in events_b.events] == [2] + + @pytest.mark.asyncio async def test_postgres_event_store_store_empty_returns_zero() -> None: # Given: a Postgres-backed store From 8163931aeb74136c5d0e63d9ca14dbd3f4cf5b4c Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 17:10:37 +0530 Subject: [PATCH 19/20] style(server): wrap upstream auth hint --- 
.../auth_framework/providers/http_upstream.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 4223a5fd..b68972ed 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -339,7 +339,10 @@ def _ensure_target_context_matches_grant( raise ForbiddenError( error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, detail="Authorization grant is target-bound but the request target is unavailable.", - hint="Use an endpoint that includes target_type and target_id in the authorization context.", + hint=( + "Use an endpoint that includes target_type and target_id " + "in the authorization context." + ), ) expected_type = context.get("target_type") From fba49ef871c70407bad4d6a6a92130d69ab6a173 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 15 May 2026 17:33:28 +0530 Subject: [PATCH 20/20] docs(server): move auth contract to docs site --- docs/README.md | 2 +- docs/auth.md | 155 ------------------------------------------------- 2 files changed, 1 insertion(+), 156 deletions(-) delete mode 100644 docs/auth.md diff --git a/docs/README.md b/docs/README.md index e53dcf13..be50b19d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,7 +10,7 @@ This repository keeps documentation concise. 
The full documentation lives on the - [Controls](https://docs.agentcontrol.dev/concepts/controls) — Define and configure control rules - [Reference](https://docs.agentcontrol.dev/core/reference) — SDK and server API reference - [Configuration](https://docs.agentcontrol.dev/core/configuration) — Environment variables, auth, and database settings -- [Server auth contract](auth.md) - Pluggable auth modes, HTTP upstream contract, and runtime JWT claims +- [Authentication](https://docs.agentcontrol.dev/core/authentication) — Pluggable auth modes, HTTP upstream contract, and runtime JWT claims - [UI Quickstart](https://docs.agentcontrol.dev/core/ui-quickstart) — Run the dashboard and manage controls visually ## Examples diff --git a/docs/auth.md b/docs/auth.md deleted file mode 100644 index c738360b..00000000 --- a/docs/auth.md +++ /dev/null @@ -1,155 +0,0 @@ -# Server Auth Contract - -Agent Control keeps authentication and authorization provider-neutral. The server asks a configured provider whether a request may perform an operation, then scopes all data access with the returned `Principal`. - -## Operations - -Operations are stable strings. Deployers map them to their own permission model. - -```text -controls.read -controls.create -controls.update -controls.delete -policies.read -policies.create -policies.update -agents.read -agents.create -agents.update -evaluators.read -observability.read -observability.write -control_bindings.read -control_bindings.write -runtime.token_exchange -runtime.use -``` - -## Principal - -Providers return a generic principal. Agent Control treats `namespace_key`, `caller_id`, `target_type`, and `target_id` as opaque strings. - -```json -{ - "namespace_key": "tenant-a", - "is_admin": false, - "caller_id": "user-or-key-id", - "target_type": "session", - "target_id": "target-123", - "scopes": ["runtime.use"], - "expires_at": "2026-05-11T15:00:00Z" -} -``` - -`namespace_key` is the tenancy boundary. 
Server queries filter by it, and namespace-aware foreign keys prevent cross-namespace references. - -## Auth Modes - -Management auth is selected by `AGENT_CONTROL_AUTH_MODE`. - -| Mode | Meaning | -| --- | --- | -| `none` | No credentials required. Intended for local development only. | -| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS` and/or `AGENT_CONTROL_ADMIN_API_KEYS`. Requires `AGENT_CONTROL_API_KEY_ENABLED=true`. `header` is accepted as a backwards-compatible alias. | -| `http_upstream` | POST each management authorization decision to `AGENT_CONTROL_AUTH_UPSTREAM_URL`. | - -When `AGENT_CONTROL_AUTH_MODE` is unset, startup selects `api_key` if local API-key validation is enabled and `none` otherwise. - -Runtime auth is selected by `AGENT_CONTROL_RUNTIME_AUTH_MODE`. - -| Mode | Meaning | -| --- | --- | -| unset | Use `jwt` when `AGENT_CONTROL_RUNTIME_TOKEN_SECRET` is set. Otherwise runtime requests fall through to management auth. | -| `none` | No runtime credentials required. Intended for local development only. | -| `api_key` | Validate runtime requests with the same local API-key mechanism. | -| `jwt` | Require target-bound runtime tokens minted by `/api/v1/auth/runtime-token-exchange`. | - -Common combinations: - -| Management | Runtime | Use case | -| --- | --- | --- | -| `api_key` | unset | Existing standalone deployments. | -| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. This does not perform per-target authorization; any valid local API key can exchange for any target in the local namespace. | -| `http_upstream` | `jwt` | External identity or authorization service for management, local token verify for high-volume runtime calls. | -| `none` | `none` | Single-process local development. Do not use in production. 
| - -## HTTP Upstream Contract - -When `AGENT_CONTROL_AUTH_MODE=http_upstream`, the server sends: - -```http -POST {AGENT_CONTROL_AUTH_UPSTREAM_URL} -``` - -```json -{ - "operation": "control_bindings.write", - "context": { - "target_type": "session", - "target_id": "target-123" - } -} -``` - -The provider forwards inbound `X-API-Key`, `Authorization`, and `Cookie` headers. Add deployer-specific header names with `AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS`, for example: - -```text -AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS=Vendor-API-Key,X-Workspace-Id -``` - -If `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN` is set, it is forwarded on `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER` or `X-Agent-Control-Service-Token` by default. - -A successful upstream response is: - -```json -{ - "namespace_key": "tenant-a", - "is_admin": false, - "caller_id": "user-or-key-id", - "target_type": "session", - "target_id": "target-123", - "scopes": ["runtime.use"], - "expires_at": "2026-05-11T15:00:00Z" -} -``` - -Only `namespace_key` is always required. `target_type` and `target_id` must be returned together when present. `expires_at` must include timezone information. - -Status handling: - -| Upstream status | Agent Control result | -| --- | --- | -| `200` | Parse the principal grant. | -| `401` | Authentication error. | -| `403` | Forbidden error. | -| `404` | Not found error. | -| `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | -| Other statuses or upstream network errors | Fail closed with `503`. | -| Malformed `200` principal response | Fail closed with `502`. | -| `200` target grant that conflicts with request context | Fail closed with `403`. | - -## Runtime JWT Claims - -`/api/v1/auth/runtime-token-exchange` is a management-style request. The configured management provider authorizes `runtime.token_exchange` for the requested target. Agent Control then mints its own HS256 JWT with `AGENT_CONTROL_RUNTIME_TOKEN_SECRET`. 
- -The token payload contains: - -```json -{ - "iss": "agent-control/server", - "domain": "runtime", - "namespace_key": "tenant-a", - "actor_id": "user-or-key-id", - "target_type": "session", - "target_id": "target-123", - "scopes": ["runtime.use"], - "iat": 1778509800, - "exp": 1778510100, - "jti": "opaque-token-id" -} -``` - -Verification requires the expected issuer, `domain="runtime"`, a valid signature, an unexpired `exp`, and `runtime.use` in `scopes`. The token is accepted only for requests whose `target_type` and `target_id` match the bound target. - -The expiry is the earlier of `AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS` and the upstream grant's `expires_at` when supplied. Runtime token TTLs are capped at 86400 seconds.