|
29 | 29 | SetPolicyResponse, |
30 | 30 | StepKey, |
31 | 31 | ) |
32 | | -from fastapi import APIRouter, Depends, Query |
| 32 | +from fastapi import APIRouter, Depends, Query, Request |
33 | 33 | from jsonschema_rs import ValidationError as JSONSchemaValidationError |
34 | 34 | from pydantic import BaseModel, ValidationError |
35 | 35 | from sqlalchemy import delete, func, select |
36 | 36 | from sqlalchemy.dialects.postgresql import insert as pg_insert |
37 | 37 | from sqlalchemy.ext.asyncio import AsyncSession |
38 | 38 |
|
39 | | -from ..auth_framework import Operation, Principal, require_operation |
| 39 | +from ..auth_framework import Operation, Principal, get_authorizer, require_operation |
40 | 40 | from ..db import get_async_db |
41 | 41 | from ..errors import ( |
42 | 42 | APIValidationError, |
43 | 43 | BadRequestError, |
44 | 44 | ConflictError, |
45 | 45 | DatabaseError, |
| 46 | + ForbiddenError, |
46 | 47 | NotFoundError, |
47 | 48 | ) |
48 | 49 | from ..logging_utils import get_logger |
|
85 | 86 | type StepKeyTuple = tuple[str, str] |
86 | 87 |
|
87 | 88 |
|
| 89 | +def _complete_target_context( |
| 90 | + target_type: object | None, |
| 91 | + target_id: object | None, |
| 92 | +) -> dict[str, str] | None: |
| 93 | + """Return target context only when both halves are present strings.""" |
| 94 | + if not isinstance(target_type, str) or not isinstance(target_id, str): |
| 95 | + return None |
| 96 | + if not target_type or not target_id: |
| 97 | + return None |
| 98 | + return {"target_type": target_type, "target_id": target_id} |
| 99 | + |
| 100 | + |
| 101 | +async def _init_agent_target_context(request: Request) -> dict[str, str] | None: |
| 102 | + """Extract optional target context from an ``initAgent`` body.""" |
| 103 | + try: |
| 104 | + body = await request.json() |
| 105 | + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation |
| 106 | + return None |
| 107 | + if not isinstance(body, dict): |
| 108 | + return None |
| 109 | + return _complete_target_context(body.get("target_type"), body.get("target_id")) |
| 110 | + |
| 111 | + |
| 112 | +def _agent_controls_target_context(request: Request) -> dict[str, str] | None: |
| 113 | + """Extract optional target context from ``GET /agents/{name}/controls``.""" |
| 114 | + return _complete_target_context( |
| 115 | + request.query_params.get("target_type"), |
| 116 | + request.query_params.get("target_id"), |
| 117 | + ) |
| 118 | + |
| 119 | + |
| 120 | +async def _authorize_target_read_if_present( |
| 121 | + request: Request, |
| 122 | + context: dict[str, str] | None, |
| 123 | +) -> Principal | None: |
| 124 | + """Require target read authorization before returning target-merged controls.""" |
| 125 | + if context is None: |
| 126 | + return None |
| 127 | + return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( |
| 128 | + request, |
| 129 | + Operation.CONTROL_BINDINGS_READ, |
| 130 | + context, |
| 131 | + ) |
| 132 | + |
| 133 | + |
| 134 | +async def _init_agent_target_principal(request: Request) -> Principal | None: |
| 135 | + return await _authorize_target_read_if_present( |
| 136 | + request, |
| 137 | + await _init_agent_target_context(request), |
| 138 | + ) |
| 139 | + |
| 140 | + |
| 141 | +async def _agent_controls_target_principal(request: Request) -> Principal | None: |
| 142 | + return await _authorize_target_read_if_present( |
| 143 | + request, |
| 144 | + _agent_controls_target_context(request), |
| 145 | + ) |
| 146 | + |
| 147 | + |
| 148 | +def _ensure_target_principal_matches_namespace( |
| 149 | + principal: Principal, |
| 150 | + target_principal: Principal | None, |
| 151 | +) -> None: |
| 152 | + """Fail closed if the target authorization resolves to a different namespace.""" |
| 153 | + if target_principal is None: |
| 154 | + return |
| 155 | + if target_principal.namespace_key == principal.namespace_key: |
| 156 | + return |
| 157 | + raise ForbiddenError( |
| 158 | + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, |
| 159 | + detail="Target authorization resolved to a different namespace.", |
| 160 | + hint="Ensure the credential is scoped to the requested target and namespace.", |
| 161 | + ) |
| 162 | + |
| 163 | + |
88 | 164 | # ============================================================================= |
89 | 165 | # List Agents Models |
90 | 166 | # ============================================================================= |
@@ -445,6 +521,7 @@ async def init_agent( |
445 | 521 | request: InitAgentRequest, |
446 | 522 | db: AsyncSession = Depends(get_async_db), |
447 | 523 | principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), |
| 524 | + target_principal: Principal | None = Depends(_init_agent_target_principal), |
448 | 525 | ) -> InitAgentResponse: |
449 | 526 | """ |
450 | 527 | Register a new agent or update an existing agent's steps and metadata. |
@@ -474,6 +551,7 @@ async def init_agent( |
474 | 551 | InitAgentResponse with created flag and the effective controls |
475 | 552 | """ |
476 | 553 | namespace_key = principal.namespace_key |
| 554 | + _ensure_target_principal_matches_namespace(principal, target_principal) |
477 | 555 |
|
478 | 556 | # Check for evaluator name collisions with built-in evaluators |
479 | 557 | builtin_names = _get_builtin_evaluator_names() |
@@ -1493,6 +1571,7 @@ async def list_agent_controls( |
1493 | 1571 | ), |
1494 | 1572 | db: AsyncSession = Depends(get_async_db), |
1495 | 1573 | principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), |
| 1574 | + target_principal: Principal | None = Depends(_agent_controls_target_principal), |
1496 | 1575 | ) -> AgentControlsResponse: |
1497 | 1576 | """ |
1498 | 1577 | List protection controls effective for an agent. |
@@ -1527,6 +1606,7 @@ async def list_agent_controls( |
1527 | 1606 | HTTPException 404: Agent not found |
1528 | 1607 | """ |
1529 | 1608 | namespace_key = principal.namespace_key |
| 1609 | + _ensure_target_principal_matches_namespace(principal, target_principal) |
1530 | 1610 |
|
1531 | 1611 | if (target_type is None) != (target_id is None): |
1532 | 1612 | raise BadRequestError( |
|
0 commit comments