Skip to content

Commit 866811a

Browse files
committed
feat(mcp): add credential_key shorthand for Bearer token propagation
Adds `credential_key` parameter to McpToolset and McpToolsetConfig as a convenience shorthand that reads a token from session state and sends it as an `Authorization: Bearer <token>` header. Internally creates a header_provider, so it combines cleanly with existing header_provider and state_header_mapping options. Aligns with credential_key usage on other ADK toolsets (e.g., OpenAPIToolset). Addresses #5103.
1 parent 68180ed commit 866811a

File tree

2 files changed

+159
-6
lines changed

2 files changed

+159
-6
lines changed

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def __init__(
253253
auth_credential: Optional[AuthCredential] = None,
254254
require_confirmation: Union[bool, Callable[..., bool]] = False,
255255
header_provider: Optional[HeaderProvider] = None,
256+
credential_key: Optional[str] = None,
256257
progress_callback: Optional[
257258
Union[ProgressFnT, ProgressCallbackFactory]
258259
] = None,
@@ -283,6 +284,11 @@ def __init__(
283284
Can be a single boolean or a callable to apply to all tools.
284285
header_provider: A callable that takes a ReadonlyContext and returns a
285286
dictionary of headers to be used for the MCP session.
287+
credential_key: A session state key whose value is sent as an
288+
``Authorization: Bearer <token>`` header on every MCP request. This is
289+
a convenience shorthand that internally creates a header_provider. If
290+
both ``credential_key`` and ``header_provider`` are provided, they are
291+
combined.
286292
progress_callback: Optional callback to receive progress notifications
287293
from MCP server during long-running tool execution. Can be either: - A
288294
``ProgressFnT`` callback that receives (progress, total, message). This
@@ -310,7 +316,20 @@ def __init__(
310316

311317
self._connection_params = connection_params
312318
self._errlog = errlog
313-
self._header_provider = header_provider
319+
320+
# Build the effective header_provider from credential_key and/or the
321+
# explicit header_provider parameter.
322+
credential_provider = (
323+
create_session_state_header_provider(state_key=credential_key)
324+
if credential_key
325+
else None
326+
)
327+
if credential_provider and header_provider:
328+
self._header_provider = create_combined_header_provider(
329+
[credential_provider, header_provider]
330+
)
331+
else:
332+
self._header_provider = credential_provider or header_provider
314333
self._progress_callback = progress_callback
315334

316335
# Create the session manager that will handle the MCP connection
@@ -580,13 +599,14 @@ def from_config(
580599
else:
581600
raise ValueError("No connection params found in McpToolsetConfig.")
582601

583-
# Create header_provider from state_header_mapping if specified
584-
header_provider = None
602+
# Build header_provider from state_header_mapping and/or credential_key.
603+
providers = []
604+
585605
if mcp_toolset_config.state_header_mapping:
586606
state_mapping = mcp_toolset_config.state_header_mapping
587607
state_format = mcp_toolset_config.state_header_format or {}
588608

589-
providers = [
609+
providers.extend([
590610
create_session_state_header_provider(
591611
state_key=state_key,
592612
header_name=header_name,
@@ -595,9 +615,18 @@ def from_config(
595615
strict=mcp_toolset_config.state_header_strict,
596616
)
597617
for state_key, header_name in state_mapping.items()
598-
]
618+
])
599619

600-
header_provider = create_combined_header_provider(providers)
620+
if mcp_toolset_config.credential_key:
621+
providers.append(
622+
create_session_state_header_provider(
623+
state_key=mcp_toolset_config.credential_key,
624+
)
625+
)
626+
627+
header_provider = (
628+
create_combined_header_provider(providers) if providers else None
629+
)
601630

602631
return cls(
603632
connection_params=connection_params,
@@ -645,6 +674,12 @@ class McpToolsetConfig(BaseToolConfig):
645674

646675
use_mcp_resources: bool = False
647676

677+
credential_key: Optional[str] = None
678+
"""A session state key whose value is sent as an ``Authorization: Bearer``
679+
header on every MCP HTTP request. Convenience shorthand that is equivalent
680+
to ``state_header_mapping: {<key>: Authorization}`` with
681+
``state_header_format: {Authorization: "Bearer {value}"}``."""
682+
648683
state_header_mapping: Optional[Dict[str, str]] = None
649684
"""Maps session state keys to HTTP header names.
650685

tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,53 @@ def test_from_config_no_state_mapping_no_provider(self):
429429
# No header provider should be created
430430
assert toolset._header_provider is None
431431

432+
def test_from_config_with_credential_key(self):
433+
"""Test that from_config creates header provider from credential_key."""
434+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
435+
from google.adk.tools.tool_configs import ToolArgsConfig
436+
437+
config = ToolArgsConfig(
438+
stdio_server_params={"command": "test_command", "args": []},
439+
credential_key="my_token",
440+
)
441+
442+
toolset = McpToolset.from_config(config, "/fake/path")
443+
444+
assert toolset._header_provider is not None
445+
446+
mock_context = Mock(spec=ReadonlyContext)
447+
mock_context.state = {"my_token": "test-jwt-123"}
448+
449+
headers = toolset._header_provider(mock_context)
450+
451+
assert headers == {"Authorization": "Bearer test-jwt-123"}
452+
453+
def test_from_config_credential_key_with_state_header_mapping(self):
454+
"""Test that credential_key and state_header_mapping combine."""
455+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
456+
from google.adk.tools.tool_configs import ToolArgsConfig
457+
458+
config = ToolArgsConfig(
459+
stdio_server_params={"command": "test_command", "args": []},
460+
credential_key="jwt_token",
461+
state_header_mapping={"tenant_id": "X-Tenant-ID"},
462+
)
463+
464+
toolset = McpToolset.from_config(config, "/fake/path")
465+
466+
mock_context = Mock(spec=ReadonlyContext)
467+
mock_context.state = {
468+
"jwt_token": "my-jwt",
469+
"tenant_id": "tenant-42",
470+
}
471+
472+
headers = toolset._header_provider(mock_context)
473+
474+
assert headers == {
475+
"Authorization": "Bearer my-jwt",
476+
"X-Tenant-ID": "tenant-42",
477+
}
478+
432479
def test_from_config_with_strict_mode(self):
433480
"""Test that from_config respects state_header_strict setting."""
434481
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
@@ -453,6 +500,77 @@ def test_from_config_with_strict_mode(self):
453500
assert "dict" in str(exc_info.value)
454501

455502

503+
class TestCredentialKey:
504+
"""Test suite for credential_key on McpToolset.__init__."""
505+
506+
def test_credential_key_creates_bearer_header(self):
507+
"""Test credential_key reads token from state and sends as Bearer."""
508+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
509+
510+
toolset = McpToolset(
511+
connection_params=StdioServerParameters(command="echo", args=[]),
512+
credential_key="auth_token",
513+
)
514+
515+
assert toolset._header_provider is not None
516+
517+
mock_context = Mock(spec=ReadonlyContext)
518+
mock_context.state = {"auth_token": "my-jwt-token"}
519+
520+
headers = toolset._header_provider(mock_context)
521+
522+
assert headers == {"Authorization": "Bearer my-jwt-token"}
523+
524+
def test_credential_key_missing_state_returns_empty(self):
525+
"""Test credential_key returns empty headers when key not in state."""
526+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
527+
528+
toolset = McpToolset(
529+
connection_params=StdioServerParameters(command="echo", args=[]),
530+
credential_key="auth_token",
531+
)
532+
533+
mock_context = Mock(spec=ReadonlyContext)
534+
mock_context.state = {}
535+
536+
headers = toolset._header_provider(mock_context)
537+
538+
assert headers == {}
539+
540+
def test_credential_key_none_means_no_provider(self):
541+
"""Test that credential_key=None does not create a provider."""
542+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
543+
544+
toolset = McpToolset(
545+
connection_params=StdioServerParameters(command="echo", args=[]),
546+
credential_key=None,
547+
)
548+
549+
assert toolset._header_provider is None
550+
551+
def test_credential_key_combines_with_header_provider(self):
552+
"""Test that credential_key and header_provider are combined."""
553+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
554+
555+
custom_provider = lambda ctx: {"X-Custom": "value"}
556+
557+
toolset = McpToolset(
558+
connection_params=StdioServerParameters(command="echo", args=[]),
559+
credential_key="auth_token",
560+
header_provider=custom_provider,
561+
)
562+
563+
mock_context = Mock(spec=ReadonlyContext)
564+
mock_context.state = {"auth_token": "my-jwt"}
565+
566+
headers = toolset._header_provider(mock_context)
567+
568+
assert headers == {
569+
"Authorization": "Bearer my-jwt",
570+
"X-Custom": "value",
571+
}
572+
573+
456574
class TestRFC7230Compliance:
457575
"""Test suite for RFC 7230 compliant header handling."""
458576

0 commit comments

Comments
 (0)