Skip to content

Commit 66208c0

Browse files
committed
feat(mcp): add credential_key shorthand for Bearer token propagation
Adds `credential_key` parameter to McpToolset and McpToolsetConfig as a convenience shorthand that reads a token from session state and sends it as an `Authorization: Bearer <token>` header. Internally creates a header_provider, so it combines cleanly with existing header_provider and state_header_mapping options. Aligns with credential_key usage on other ADK toolsets (e.g., OpenAPIToolset). Addresses #5103.
1 parent 0a7573b commit 66208c0

2 files changed

Lines changed: 159 additions & 6 deletions

File tree

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def __init__(
254254
auth_credential: Optional[AuthCredential] = None,
255255
require_confirmation: Union[bool, Callable[..., bool]] = False,
256256
header_provider: Optional[HeaderProvider] = None,
257+
credential_key: Optional[str] = None,
257258
progress_callback: Optional[
258259
Union[ProgressFnT, ProgressCallbackFactory]
259260
] = None,
@@ -284,6 +285,11 @@ def __init__(
284285
Can be a single boolean or a callable to apply to all tools.
285286
header_provider: A callable that takes a ReadonlyContext and returns a
286287
dictionary of headers to be used for the MCP session.
288+
credential_key: A session state key whose value is sent as an
289+
``Authorization: Bearer <token>`` header on every MCP request. This is
290+
a convenience shorthand that internally creates a header_provider. If
291+
both ``credential_key`` and ``header_provider`` are provided, they are
292+
combined.
287293
progress_callback: Optional callback to receive progress notifications
288294
from MCP server during long-running tool execution. Can be either: - A
289295
``ProgressFnT`` callback that receives (progress, total, message). This
@@ -320,7 +326,20 @@ def __init__(
320326

321327
self._connection_params = connection_params
322328
self._errlog = errlog
323-
self._header_provider = header_provider
329+
330+
# Build the effective header_provider from credential_key and/or the
331+
# explicit header_provider parameter.
332+
credential_provider = (
333+
create_session_state_header_provider(state_key=credential_key)
334+
if credential_key
335+
else None
336+
)
337+
if credential_provider and header_provider:
338+
self._header_provider = create_combined_header_provider(
339+
[credential_provider, header_provider]
340+
)
341+
else:
342+
self._header_provider = credential_provider or header_provider
324343
self._progress_callback = progress_callback
325344

326345
# Create the session manager that will handle the MCP connection
@@ -605,13 +624,14 @@ def from_config(
605624
else:
606625
raise ValueError("No connection params found in McpToolsetConfig.")
607626

608-
# Create header_provider from state_header_mapping if specified
609-
header_provider = None
627+
# Build header_provider from state_header_mapping and/or credential_key.
628+
providers = []
629+
610630
if mcp_toolset_config.state_header_mapping:
611631
state_mapping = mcp_toolset_config.state_header_mapping
612632
state_format = mcp_toolset_config.state_header_format or {}
613633

614-
providers = [
634+
providers.extend([
615635
create_session_state_header_provider(
616636
state_key=state_key,
617637
header_name=header_name,
@@ -620,9 +640,18 @@ def from_config(
620640
strict=mcp_toolset_config.state_header_strict,
621641
)
622642
for state_key, header_name in state_mapping.items()
623-
]
643+
])
624644

625-
header_provider = create_combined_header_provider(providers)
645+
if mcp_toolset_config.credential_key:
646+
providers.append(
647+
create_session_state_header_provider(
648+
state_key=mcp_toolset_config.credential_key,
649+
)
650+
)
651+
652+
header_provider = (
653+
create_combined_header_provider(providers) if providers else None
654+
)
626655

627656
return cls(
628657
connection_params=connection_params,
@@ -684,6 +713,12 @@ class McpToolsetConfig(BaseToolConfig):
684713

685714
use_mcp_resources: bool = False
686715

716+
credential_key: Optional[str] = None
717+
"""A session state key whose value is sent as an ``Authorization: Bearer``
718+
header on every MCP HTTP request. Convenience shorthand that is equivalent
719+
to ``state_header_mapping: {<key>: Authorization}`` with
720+
``state_header_format: {Authorization: "Bearer {value}"}``."""
721+
687722
state_header_mapping: Optional[Dict[str, str]] = None
688723
"""Maps session state keys to HTTP header names.
689724

tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,53 @@ def test_from_config_no_state_mapping_no_provider(self):
429429
# No header provider should be created
430430
assert toolset._header_provider is None
431431

432+
def test_from_config_with_credential_key(self):
433+
"""Test that from_config creates header provider from credential_key."""
434+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
435+
from google.adk.tools.tool_configs import ToolArgsConfig
436+
437+
config = ToolArgsConfig(
438+
stdio_server_params={"command": "test_command", "args": []},
439+
credential_key="my_token",
440+
)
441+
442+
toolset = McpToolset.from_config(config, "/fake/path")
443+
444+
assert toolset._header_provider is not None
445+
446+
mock_context = Mock(spec=ReadonlyContext)
447+
mock_context.state = {"my_token": "test-jwt-123"}
448+
449+
headers = toolset._header_provider(mock_context)
450+
451+
assert headers == {"Authorization": "Bearer test-jwt-123"}
452+
453+
def test_from_config_credential_key_with_state_header_mapping(self):
454+
"""Test that credential_key and state_header_mapping combine."""
455+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
456+
from google.adk.tools.tool_configs import ToolArgsConfig
457+
458+
config = ToolArgsConfig(
459+
stdio_server_params={"command": "test_command", "args": []},
460+
credential_key="jwt_token",
461+
state_header_mapping={"tenant_id": "X-Tenant-ID"},
462+
)
463+
464+
toolset = McpToolset.from_config(config, "/fake/path")
465+
466+
mock_context = Mock(spec=ReadonlyContext)
467+
mock_context.state = {
468+
"jwt_token": "my-jwt",
469+
"tenant_id": "tenant-42",
470+
}
471+
472+
headers = toolset._header_provider(mock_context)
473+
474+
assert headers == {
475+
"Authorization": "Bearer my-jwt",
476+
"X-Tenant-ID": "tenant-42",
477+
}
478+
432479
def test_from_config_with_strict_mode(self):
433480
"""Test that from_config respects state_header_strict setting."""
434481
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
@@ -453,6 +500,77 @@ def test_from_config_with_strict_mode(self):
453500
assert "dict" in str(exc_info.value)
454501

455502

503+
class TestCredentialKey:
504+
"""Test suite for credential_key on McpToolset.__init__."""
505+
506+
def test_credential_key_creates_bearer_header(self):
507+
"""Test credential_key reads token from state and sends as Bearer."""
508+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
509+
510+
toolset = McpToolset(
511+
connection_params=StdioServerParameters(command="echo", args=[]),
512+
credential_key="auth_token",
513+
)
514+
515+
assert toolset._header_provider is not None
516+
517+
mock_context = Mock(spec=ReadonlyContext)
518+
mock_context.state = {"auth_token": "my-jwt-token"}
519+
520+
headers = toolset._header_provider(mock_context)
521+
522+
assert headers == {"Authorization": "Bearer my-jwt-token"}
523+
524+
def test_credential_key_missing_state_returns_empty(self):
525+
"""Test credential_key returns empty headers when key not in state."""
526+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
527+
528+
toolset = McpToolset(
529+
connection_params=StdioServerParameters(command="echo", args=[]),
530+
credential_key="auth_token",
531+
)
532+
533+
mock_context = Mock(spec=ReadonlyContext)
534+
mock_context.state = {}
535+
536+
headers = toolset._header_provider(mock_context)
537+
538+
assert headers == {}
539+
540+
def test_credential_key_none_means_no_provider(self):
541+
"""Test that credential_key=None does not create a provider."""
542+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
543+
544+
toolset = McpToolset(
545+
connection_params=StdioServerParameters(command="echo", args=[]),
546+
credential_key=None,
547+
)
548+
549+
assert toolset._header_provider is None
550+
551+
def test_credential_key_combines_with_header_provider(self):
552+
"""Test that credential_key and header_provider are combined."""
553+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
554+
555+
custom_provider = lambda ctx: {"X-Custom": "value"}
556+
557+
toolset = McpToolset(
558+
connection_params=StdioServerParameters(command="echo", args=[]),
559+
credential_key="auth_token",
560+
header_provider=custom_provider,
561+
)
562+
563+
mock_context = Mock(spec=ReadonlyContext)
564+
mock_context.state = {"auth_token": "my-jwt"}
565+
566+
headers = toolset._header_provider(mock_context)
567+
568+
assert headers == {
569+
"Authorization": "Bearer my-jwt",
570+
"X-Custom": "value",
571+
}
572+
573+
456574
class TestRFC7230Compliance:
457575
"""Test suite for RFC 7230 compliant header handling."""
458576

0 commit comments

Comments
 (0)