Skip to content

Commit fb63493

Browse files
authored
feat: implement inbound auth (#467)
* feat: extend JWT wizard with allowedScopes and agent OAuth credential inputs * feat: auto-create managed OAuth credential for CUSTOM_JWT gateway * feat: add CLI flags for CUSTOM_JWT agent OAuth credentials * feat: add CUSTOM_JWT Bearer token auth to agent templates (Strands, LangChain, OpenAI, Google ADK) * feat: protect managed credentials from accidental deletion * test: add tests for CUSTOM_JWT CLI validation and managed credential protection * fix: resolve httpx import collision between AWS_IAM and CUSTOM_JWT templates * fix: use placeholder instead of initialValue for gateway discovery URL * feat: wire CUSTOM_JWT inbound auth through AgentCore identity system
1 parent 61ef4bc commit fb63493

22 files changed

Lines changed: 459 additions & 19 deletions

File tree

src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,6 +1726,23 @@ logger = logging.getLogger(__name__)
17261726
import httpx
17271727
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
17281728
{{/if}}
1729+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
1730+
from bedrock_agentcore.identity import requires_access_token
1731+
{{/if}}
1732+
1733+
{{#each gatewayProviders}}
1734+
{{#if (eq authType "CUSTOM_JWT")}}
1735+
@requires_access_token(
1736+
provider_name="{{credentialProviderName}}",
1737+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
1738+
auth_flow="M2M",
1739+
)
1740+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
1741+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
1742+
return access_token
1743+
1744+
{{/if}}
1745+
{{/each}}
17291746
17301747
def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
17311748
"""Returns MCP Toolsets for all configured gateways."""
@@ -1740,6 +1757,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
17401757
url=url,
17411758
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
17421759
)))
1760+
{{else if (eq authType "CUSTOM_JWT")}}
1761+
token = _get_bearer_token_{{snakeCase name}}()
1762+
headers = {"Authorization": f"Bearer {token}"} if token else None
1763+
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
17431764
{{else}}
17441765
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
17451766
{{/if}}
@@ -2012,6 +2033,23 @@ logger = logging.getLogger(__name__)
20122033
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
20132034
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
20142035
{{/if}}
2036+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2037+
from bedrock_agentcore.identity import requires_access_token
2038+
{{/if}}
2039+
2040+
{{#each gatewayProviders}}
2041+
{{#if (eq authType "CUSTOM_JWT")}}
2042+
@requires_access_token(
2043+
provider_name="{{credentialProviderName}}",
2044+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2045+
auth_flow="M2M",
2046+
)
2047+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2048+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2049+
return access_token
2050+
2051+
{{/if}}
2052+
{{/each}}
20152053
20162054
def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
20172055
"""Returns an MCP Client connected to all configured gateways."""
@@ -2023,6 +2061,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
20232061
session = create_aws_session()
20242062
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
20252063
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
2064+
{{else if (eq authType "CUSTOM_JWT")}}
2065+
token = _get_bearer_token_{{snakeCase name}}()
2066+
headers = {"Authorization": f"Bearer {token}"} if token else None
2067+
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
20262068
{{else}}
20272069
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
20282070
{{/if}}
@@ -2438,6 +2480,23 @@ logger = logging.getLogger(__name__)
24382480
import httpx
24392481
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
24402482
{{/if}}
2483+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2484+
from bedrock_agentcore.identity import requires_access_token
2485+
{{/if}}
2486+
2487+
{{#each gatewayProviders}}
2488+
{{#if (eq authType "CUSTOM_JWT")}}
2489+
@requires_access_token(
2490+
provider_name="{{credentialProviderName}}",
2491+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2492+
auth_flow="M2M",
2493+
)
2494+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2495+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2496+
return access_token
2497+
2498+
{{/if}}
2499+
{{/each}}
24412500
24422501
def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
24432502
"""Returns MCP servers for all configured gateways."""
@@ -2452,6 +2511,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
24522511
name="{{name}}",
24532512
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
24542513
))
2514+
{{else if (eq authType "CUSTOM_JWT")}}
2515+
token = _get_bearer_token_{{snakeCase name}}()
2516+
headers = {"Authorization": f"Bearer {token}"} if token else {}
2517+
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
24552518
{{else}}
24562519
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
24572520
{{/if}}
@@ -2749,7 +2812,23 @@ logger = logging.getLogger(__name__)
27492812
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
27502813
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
27512814
{{/if}}
2815+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
2816+
from bedrock_agentcore.identity import requires_access_token
2817+
{{/if}}
2818+
2819+
{{#each gatewayProviders}}
2820+
{{#if (eq authType "CUSTOM_JWT")}}
2821+
@requires_access_token(
2822+
provider_name="{{credentialProviderName}}",
2823+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
2824+
auth_flow="M2M",
2825+
)
2826+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
2827+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
2828+
return access_token
27522829
2830+
{{/if}}
2831+
{{/each}}
27532832
{{#each gatewayProviders}}
27542833
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
27552834
"""Returns an MCP Client connected to the {{name}} gateway."""
@@ -2759,6 +2838,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
27592838
return None
27602839
{{#if (eq authType "AWS_IAM")}}
27612840
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
2841+
{{else if (eq authType "CUSTOM_JWT")}}
2842+
token = _get_bearer_token_{{snakeCase name}}()
2843+
headers = {"Authorization": f"Bearer {token}"} if token else {}
2844+
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
27622845
{{else}}
27632846
return MCPClient(lambda: streamablehttp_client(url))
27642847
{{/if}}

src/assets/python/googleadk/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@
1010
import httpx
1111
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1212
{{/if}}
13+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
14+
from bedrock_agentcore.identity import requires_access_token
15+
{{/if}}
16+
17+
{{#each gatewayProviders}}
18+
{{#if (eq authType "CUSTOM_JWT")}}
19+
@requires_access_token(
20+
provider_name="{{credentialProviderName}}",
21+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
22+
auth_flow="M2M",
23+
)
24+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
25+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
26+
return access_token
27+
28+
{{/if}}
29+
{{/each}}
1330

1431
def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
1532
"""Returns MCP Toolsets for all configured gateways."""
@@ -24,6 +41,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
2441
url=url,
2542
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
2643
)))
44+
{{else if (eq authType "CUSTOM_JWT")}}
45+
token = _get_bearer_token_{{snakeCase name}}()
46+
headers = {"Authorization": f"Bearer {token}"} if token else None
47+
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
2748
{{else}}
2849
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
2950
{{/if}}

src/assets/python/langchain_langgraph/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@
88
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
99
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1010
{{/if}}
11+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
12+
from bedrock_agentcore.identity import requires_access_token
13+
{{/if}}
14+
15+
{{#each gatewayProviders}}
16+
{{#if (eq authType "CUSTOM_JWT")}}
17+
@requires_access_token(
18+
provider_name="{{credentialProviderName}}",
19+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
20+
auth_flow="M2M",
21+
)
22+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
23+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
24+
return access_token
25+
26+
{{/if}}
27+
{{/each}}
1128

1229
def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
1330
"""Returns an MCP Client connected to all configured gateways."""
@@ -19,6 +36,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
1936
session = create_aws_session()
2037
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
2138
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
39+
{{else if (eq authType "CUSTOM_JWT")}}
40+
token = _get_bearer_token_{{snakeCase name}}()
41+
headers = {"Authorization": f"Bearer {token}"} if token else None
42+
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
2243
{{else}}
2344
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
2445
{{/if}}

src/assets/python/openaiagents/base/mcp_client/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99
import httpx
1010
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
1111
{{/if}}
12+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
13+
from bedrock_agentcore.identity import requires_access_token
14+
{{/if}}
15+
16+
{{#each gatewayProviders}}
17+
{{#if (eq authType "CUSTOM_JWT")}}
18+
@requires_access_token(
19+
provider_name="{{credentialProviderName}}",
20+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
21+
auth_flow="M2M",
22+
)
23+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
24+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
25+
return access_token
26+
27+
{{/if}}
28+
{{/each}}
1229

1330
def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
1431
"""Returns MCP servers for all configured gateways."""
@@ -23,6 +40,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
2340
name="{{name}}",
2441
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
2542
))
43+
{{else if (eq authType "CUSTOM_JWT")}}
44+
token = _get_bearer_token_{{snakeCase name}}()
45+
headers = {"Authorization": f"Bearer {token}"} if token else {}
46+
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
2647
{{else}}
2748
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
2849
{{/if}}

src/assets/python/strands/base/mcp_client/client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,23 @@
99
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
1010
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
1111
{{/if}}
12+
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
13+
from bedrock_agentcore.identity import requires_access_token
14+
{{/if}}
15+
16+
{{#each gatewayProviders}}
17+
{{#if (eq authType "CUSTOM_JWT")}}
18+
@requires_access_token(
19+
provider_name="{{credentialProviderName}}",
20+
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
21+
auth_flow="M2M",
22+
)
23+
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
24+
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
25+
return access_token
1226

27+
{{/if}}
28+
{{/each}}
1329
{{#each gatewayProviders}}
1430
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
1531
"""Returns an MCP Client connected to the {{name}} gateway."""
@@ -19,6 +35,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
1935
return None
2036
{{#if (eq authType "AWS_IAM")}}
2137
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
38+
{{else if (eq authType "CUSTOM_JWT")}}
39+
token = _get_bearer_token_{{snakeCase name}}()
40+
headers = {"Authorization": f"Bearer {token}"} if token else {}
41+
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
2242
{{else}}
2343
return MCPClient(lambda: streamablehttp_client(url))
2444
{{/if}}

src/cli/commands/add/__tests__/validate.test.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,47 @@ describe('validate', () => {
240240
expect(validateAddGatewayOptions(validGatewayOptionsNone)).toEqual({ valid: true });
241241
expect(validateAddGatewayOptions(validGatewayOptionsJwt)).toEqual({ valid: true });
242242
});
243+
244+
// AC15: agentClientId and agentClientSecret must be provided together
245+
it('returns error when agentClientId provided without agentClientSecret', () => {
246+
const result = validateAddGatewayOptions({
247+
...validGatewayOptionsJwt,
248+
agentClientId: 'my-client-id',
249+
});
250+
expect(result.valid).toBe(false);
251+
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
252+
});
253+
254+
it('returns error when agentClientSecret provided without agentClientId', () => {
255+
const result = validateAddGatewayOptions({
256+
...validGatewayOptionsJwt,
257+
agentClientSecret: 'my-secret',
258+
});
259+
expect(result.valid).toBe(false);
260+
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
261+
});
262+
263+
// AC16: agent credentials only valid with CUSTOM_JWT
264+
it('returns error when agent credentials used with non-CUSTOM_JWT authorizer', () => {
265+
const result = validateAddGatewayOptions({
266+
...validGatewayOptionsNone,
267+
agentClientId: 'my-client-id',
268+
agentClientSecret: 'my-secret',
269+
});
270+
expect(result.valid).toBe(false);
271+
expect(result.error).toBe('Agent OAuth credentials are only valid with CUSTOM_JWT authorizer');
272+
});
273+
274+
// AC17: valid CUSTOM_JWT with agent credentials passes
275+
it('passes for CUSTOM_JWT with agent credentials', () => {
276+
const result = validateAddGatewayOptions({
277+
...validGatewayOptionsJwt,
278+
agentClientId: 'my-client-id',
279+
agentClientSecret: 'my-secret',
280+
allowedScopes: 'scope1,scope2',
281+
});
282+
expect(result.valid).toBe(true);
283+
});
243284
});
244285

245286
describe('validateAddGatewayTargetOptions', () => {

src/cli/commands/add/actions.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ export interface ValidatedAddGatewayOptions {
6464
discoveryUrl?: string;
6565
allowedAudience?: string;
6666
allowedClients?: string;
67+
allowedScopes?: string;
68+
agentClientId?: string;
69+
agentClientSecret?: string;
6770
agents?: string;
6871
}
6972

@@ -267,6 +270,14 @@ function buildGatewayConfig(options: ValidatedAddGatewayOptions): AddGatewayConf
267270
.allowedClients!.split(',')
268271
.map(s => s.trim())
269272
.filter(Boolean),
273+
allowedScopes: options.allowedScopes
274+
? options.allowedScopes
275+
.split(',')
276+
.map(s => s.trim())
277+
.filter(Boolean)
278+
: undefined,
279+
agentClientId: options.agentClientId,
280+
agentClientSecret: options.agentClientSecret,
270281
};
271282
}
272283

src/cli/commands/add/command.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ async function handleAddGatewayCLI(options: AddGatewayOptions): Promise<void> {
8282
discoveryUrl: options.discoveryUrl,
8383
allowedAudience: options.allowedAudience,
8484
allowedClients: options.allowedClients,
85+
allowedScopes: options.allowedScopes,
86+
agentClientId: options.agentClientId,
87+
agentClientSecret: options.agentClientSecret,
8588
agents: options.agents,
8689
});
8790

@@ -272,6 +275,9 @@ export function registerAdd(program: Command) {
272275
.option('--discovery-url <url>', 'OIDC discovery URL (required for CUSTOM_JWT)')
273276
.option('--allowed-audience <values>', 'Comma-separated allowed audience values (required for CUSTOM_JWT)')
274277
.option('--allowed-clients <values>', 'Comma-separated allowed client IDs (required for CUSTOM_JWT)')
278+
.option('--allowed-scopes <scopes>', 'Comma-separated allowed scopes (optional for CUSTOM_JWT)')
279+
.option('--agent-client-id <id>', 'Agent OAuth client ID for Bearer token auth (CUSTOM_JWT)')
280+
.option('--agent-client-secret <secret>', 'Agent OAuth client secret (CUSTOM_JWT)')
275281
.option('--json', 'Output as JSON')
276282
.action(async options => {
277283
requireProject();

src/cli/commands/add/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ export interface AddGatewayOptions {
3131
discoveryUrl?: string;
3232
allowedAudience?: string;
3333
allowedClients?: string;
34+
allowedScopes?: string;
35+
agentClientId?: string;
36+
agentClientSecret?: string;
3437
agents?: string;
3538
json?: boolean;
3639
}

0 commit comments

Comments
 (0)