Skip to content

Commit f295ffd

Browse files
committed
feat: implement inbound auth (#467)
* feat: extend JWT wizard with allowedScopes and agent OAuth credential inputs * feat: auto-create managed OAuth credential for CUSTOM_JWT gateway * feat: add CLI flags for CUSTOM_JWT agent OAuth credentials * feat: add CUSTOM_JWT Bearer token auth to agent templates (Strands, LangChain, OpenAI, Google ADK) * feat: protect managed credentials from accidental deletion * test: add tests for CUSTOM_JWT CLI validation and managed credential protection * fix: resolve httpx import collision between AWS_IAM and CUSTOM_JWT templates * fix: use placeholder instead of initialValue for gateway discovery URL * feat: wire CUSTOM_JWT inbound auth through AgentCore identity system
1 parent 3ef1d1f commit f295ffd

22 files changed

Lines changed: 459 additions & 19 deletions

File tree

src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,23 @@ logger = logging.getLogger(__name__)
17491749
import httpx
17501750
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
17511751
{{/if}}
1752+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
1753+
from bedrock_agentcore.identity import requires_access_token
1754+
{{/if}}
1755+
1756+
{{#each gatewayProviders}}
1757+
{{#if (eq authType "CUSTOM_JWT")}}
1758+
@requires_access_token(
1759+
provider_name="{{credentialProviderName}}",
1760+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
1761+
auth_flow="M2M",
1762+
)
1763+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
1764+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
1765+
return access_token
1766+
1767+
{{/if}}
1768+
{{/each}}
17521769
17531770
def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
17541771
"""Returns MCP Toolsets for all configured gateways."""
@@ -1763,6 +1780,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
17631780
url=url,
17641781
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
17651782
)))
1783+
{{else if (eq authType "CUSTOM_JWT")}}
1784+
token = _get_bearer_token_{{snakeCase name}}()
1785+
headers = {"Authorization": f"Bearer {token}"} if token else None
1786+
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
17661787
{{else}}
17671788
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
17681789
{{/if}}
@@ -2035,6 +2056,23 @@ logger = logging.getLogger(__name__)
20352056
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
20362057
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
20372058
{{/if}}
2059+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2060+
from bedrock_agentcore.identity import requires_access_token
2061+
{{/if}}
2062+
2063+
{{#each gatewayProviders}}
2064+
{{#if (eq authType "CUSTOM_JWT")}}
2065+
@requires_access_token(
2066+
provider_name="{{credentialProviderName}}",
2067+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2068+
auth_flow="M2M",
2069+
)
2070+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2071+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2072+
return access_token
2073+
2074+
{{/if}}
2075+
{{/each}}
20382076
20392077
def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
20402078
"""Returns an MCP Client connected to all configured gateways."""
@@ -2046,6 +2084,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
20462084
session = create_aws_session()
20472085
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
20482086
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
2087+
{{else if (eq authType "CUSTOM_JWT")}}
2088+
token = _get_bearer_token_{{snakeCase name}}()
2089+
headers = {"Authorization": f"Bearer {token}"} if token else None
2090+
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
20492091
{{else}}
20502092
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
20512093
{{/if}}
@@ -2460,6 +2502,23 @@ logger = logging.getLogger(__name__)
24602502
import httpx
24612503
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
24622504
{{/if}}
2505+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2506+
from bedrock_agentcore.identity import requires_access_token
2507+
{{/if}}
2508+
2509+
{{#each gatewayProviders}}
2510+
{{#if (eq authType "CUSTOM_JWT")}}
2511+
@requires_access_token(
2512+
provider_name="{{credentialProviderName}}",
2513+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2514+
auth_flow="M2M",
2515+
)
2516+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2517+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2518+
return access_token
2519+
2520+
{{/if}}
2521+
{{/each}}
24632522
24642523
def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
24652524
"""Returns MCP servers for all configured gateways."""
@@ -2474,6 +2533,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
24742533
name="{{name}}",
24752534
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
24762535
))
2536+
{{else if (eq authType "CUSTOM_JWT")}}
2537+
token = _get_bearer_token_{{snakeCase name}}()
2538+
headers = {"Authorization": f"Bearer {token}"} if token else {}
2539+
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
24772540
{{else}}
24782541
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
24792542
{{/if}}
@@ -2771,7 +2834,23 @@ logger = logging.getLogger(__name__)
27712834
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
27722835
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
27732836
{{/if}}
2837+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2838+
from bedrock_agentcore.identity import requires_access_token
2839+
{{/if}}
2840+
2841+
{{#each gatewayProviders}}
2842+
{{#if (eq authType "CUSTOM_JWT")}}
2843+
@requires_access_token(
2844+
provider_name="{{credentialProviderName}}",
2845+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2846+
auth_flow="M2M",
2847+
)
2848+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2849+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2850+
return access_token
27742851
2852+
{{/if}}
2853+
{{/each}}
27752854
{{#each gatewayProviders}}
27762855
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
27772856
"""Returns an MCP Client connected to the {{name}} gateway."""
@@ -2781,6 +2860,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
27812860
return None
27822861
{{#if (eq authType "AWS_IAM")}}
27832862
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
2863+
{{else if (eq authType "CUSTOM_JWT")}}
2864+
token = _get_bearer_token_{{snakeCase name}}()
2865+
headers = {"Authorization": f"Bearer {token}"} if token else {}
2866+
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
27842867
{{else}}
27852868
return MCPClient(lambda: streamablehttp_client(url))
27862869
{{/if}}

src/assets/python/googleadk/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@
1010
import httpx
1111
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1212
{{/if}}
13+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
14+
from bedrock_agentcore.identity import requires_access_token
15+
{{/if}}
16+
17+
{{#each gatewayProviders}}
18+
{{#if (eq authType "CUSTOM_JWT")}}
19+
@requires_access_token(
20+
provider_name="{{credentialProviderName}}",
21+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
22+
auth_flow="M2M",
23+
)
24+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
25+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
26+
return access_token
27+
28+
{{/if}}
29+
{{/each}}
1330

1431
def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
1532
"""Returns MCP Toolsets for all configured gateways."""
@@ -24,6 +41,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
2441
url=url,
2542
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
2643
)))
44+
{{else if (eq authType "CUSTOM_JWT")}}
45+
token = _get_bearer_token_{{snakeCase name}}()
46+
headers = {"Authorization": f"Bearer {token}"} if token else None
47+
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
2748
{{else}}
2849
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
2950
{{/if}}

src/assets/python/langchain_langgraph/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@
88
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
99
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1010
{{/if}}
11+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
12+
from bedrock_agentcore.identity import requires_access_token
13+
{{/if}}
14+
15+
{{#each gatewayProviders}}
16+
{{#if (eq authType "CUSTOM_JWT")}}
17+
@requires_access_token(
18+
provider_name="{{credentialProviderName}}",
19+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
20+
auth_flow="M2M",
21+
)
22+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
23+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
24+
return access_token
25+
26+
{{/if}}
27+
{{/each}}
1128

1229
def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
1330
"""Returns an MCP Client connected to all configured gateways."""
@@ -19,6 +36,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
1936
session = create_aws_session()
2037
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
2138
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
39+
{{else if (eq authType "CUSTOM_JWT")}}
40+
token = _get_bearer_token_{{snakeCase name}}()
41+
headers = {"Authorization": f"Bearer {token}"} if token else None
42+
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
2243
{{else}}
2344
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
2445
{{/if}}

src/assets/python/openaiagents/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99
import httpx
1010
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1111
{{/if}}
12+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
13+
from bedrock_agentcore.identity import requires_access_token
14+
{{/if}}
15+
16+
{{#each gatewayProviders}}
17+
{{#if (eq authType "CUSTOM_JWT")}}
18+
@requires_access_token(
19+
provider_name="{{credentialProviderName}}",
20+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
21+
auth_flow="M2M",
22+
)
23+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
24+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
25+
return access_token
26+
27+
{{/if}}
28+
{{/each}}
1229

1330
def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
1431
"""Returns MCP servers for all configured gateways."""
@@ -23,6 +40,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
2340
name="{{name}}",
2441
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
2542
))
43+
{{else if (eq authType "CUSTOM_JWT")}}
44+
token = _get_bearer_token_{{snakeCase name}}()
45+
headers = {"Authorization": f"Bearer {token}"} if token else {}
46+
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
2647
{{else}}
2748
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
2849
{{/if}}

src/assets/python/strands/base/mcp_client/client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,23 @@
99
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
1010
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
1111
{{/if}}
12+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
13+
from bedrock_agentcore.identity import requires_access_token
14+
{{/if}}
15+
16+
{{#each gatewayProviders}}
17+
{{#if (eq authType "CUSTOM_JWT")}}
18+
@requires_access_token(
19+
provider_name="{{credentialProviderName}}",
20+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
21+
auth_flow="M2M",
22+
)
23+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
24+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
25+
return access_token
1226

27+
{{/if}}
28+
{{/each}}
1329
{{#each gatewayProviders}}
1430
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
1531
"""Returns an MCP Client connected to the {{name}} gateway."""
@@ -19,6 +35,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
1935
return None
2036
{{#if (eq authType "AWS_IAM")}}
2137
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
38+
{{else if (eq authType "CUSTOM_JWT")}}
39+
token = _get_bearer_token_{{snakeCase name}}()
40+
headers = {"Authorization": f"Bearer {token}"} if token else {}
41+
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
2242
{{else}}
2343
return MCPClient(lambda: streamablehttp_client(url))
2444
{{/if}}

src/cli/commands/add/__tests__/validate.test.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,47 @@ describe('validate', () => {
240240
expect(validateAddGatewayOptions(validGatewayOptionsNone)).toEqual({ valid: true });
241241
expect(validateAddGatewayOptions(validGatewayOptionsJwt)).toEqual({ valid: true });
242242
});
243+
244+
// AC15: agentClientId and agentClientSecret must be provided together
245+
it('returns error when agentClientId provided without agentClientSecret', () => {
246+
const result = validateAddGatewayOptions({
247+
...validGatewayOptionsJwt,
248+
agentClientId: 'my-client-id',
249+
});
250+
expect(result.valid).toBe(false);
251+
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
252+
});
253+
254+
it('returns error when agentClientSecret provided without agentClientId', () => {
255+
const result = validateAddGatewayOptions({
256+
...validGatewayOptionsJwt,
257+
agentClientSecret: 'my-secret',
258+
});
259+
expect(result.valid).toBe(false);
260+
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
261+
});
262+
263+
// AC16: agent credentials only valid with CUSTOM_JWT
264+
it('returns error when agent credentials used with non-CUSTOM_JWT authorizer', () => {
265+
const result = validateAddGatewayOptions({
266+
...validGatewayOptionsNone,
267+
agentClientId: 'my-client-id',
268+
agentClientSecret: 'my-secret',
269+
});
270+
expect(result.valid).toBe(false);
271+
expect(result.error).toBe('Agent OAuth credentials are only valid with CUSTOM_JWT authorizer');
272+
});
273+
274+
// AC17: valid CUSTOM_JWT with agent credentials passes
275+
it('passes for CUSTOM_JWT with agent credentials', () => {
276+
const result = validateAddGatewayOptions({
277+
...validGatewayOptionsJwt,
278+
agentClientId: 'my-client-id',
279+
agentClientSecret: 'my-secret',
280+
allowedScopes: 'scope1,scope2',
281+
});
282+
expect(result.valid).toBe(true);
283+
});
243284
});
244285

245286
describe('validateAddGatewayTargetOptions', () => {

src/cli/commands/add/actions.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ export interface ValidatedAddGatewayOptions {
6464
discoveryUrl?: string;
6565
allowedAudience?: string;
6666
allowedClients?: string;
67+
allowedScopes?: string;
68+
agentClientId?: string;
69+
agentClientSecret?: string;
6770
agents?: string;
6871
}
6972

@@ -267,6 +270,14 @@ function buildGatewayConfig(options: ValidatedAddGatewayOptions): AddGatewayConf
267270
.allowedClients!.split(',')
268271
.map(s => s.trim())
269272
.filter(Boolean),
273+
allowedScopes: options.allowedScopes
274+
? options.allowedScopes
275+
.split(',')
276+
.map(s => s.trim())
277+
.filter(Boolean)
278+
: undefined,
279+
agentClientId: options.agentClientId,
280+
agentClientSecret: options.agentClientSecret,
270281
};
271282
}
272283

src/cli/commands/add/command.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ async function handleAddGatewayCLI(options: AddGatewayOptions): Promise<void> {
8282
discoveryUrl: options.discoveryUrl,
8383
allowedAudience: options.allowedAudience,
8484
allowedClients: options.allowedClients,
85+
allowedScopes: options.allowedScopes,
86+
agentClientId: options.agentClientId,
87+
agentClientSecret: options.agentClientSecret,
8588
agents: options.agents,
8689
});
8790

@@ -272,6 +275,9 @@ export function registerAdd(program: Command) {
272275
.option('--discovery-url <url>', 'OIDC discovery URL (required for CUSTOM_JWT)')
273276
.option('--allowed-audience <values>', 'Comma-separated allowed audience values (required for CUSTOM_JWT)')
274277
.option('--allowed-clients <values>', 'Comma-separated allowed client IDs (required for CUSTOM_JWT)')
278+
.option('--allowed-scopes <scopes>', 'Comma-separated allowed scopes (optional for CUSTOM_JWT)')
279+
.option('--agent-client-id <id>', 'Agent OAuth client ID for Bearer token auth (CUSTOM_JWT)')
280+
.option('--agent-client-secret <secret>', 'Agent OAuth client secret (CUSTOM_JWT)')
275281
.option('--json', 'Output as JSON')
276282
.action(async options => {
277283
requireProject();

src/cli/commands/add/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ export interface AddGatewayOptions {
3131
discoveryUrl?: string;
3232
allowedAudience?: string;
3333
allowedClients?: string;
34+
allowedScopes?: string;
35+
agentClientId?: string;
36+
agentClientSecret?: string;
3437
agents?: string;
3538
json?: boolean;
3639
}

0 commit comments

Comments
 (0)