Skip to content

Commit 34aadc4

Browse files
committed
fix(mcp): harden header handling and add config validation
- Add CRLF/TAB to _DANGEROUS_CHARS and sanitize auth headers - Standardize header merge order (header_provider takes precedence) - Add model validator for state_header_mapping/state_header_format - Warn on duplicate header names in combined providers - Add request_state/state_delta/invocation_id to sync Runner.run() - Use TYPE_CHECKING guard in types.py
1 parent 31a2981 commit 34aadc4

5 files changed

Lines changed: 191 additions & 21 deletions

File tree

src/google/adk/runners.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,10 @@ def run(
439439
*,
440440
user_id: str,
441441
session_id: str,
442-
new_message: types.Content,
442+
new_message: Optional[types.Content] = None,
443+
invocation_id: Optional[str] = None,
444+
state_delta: Optional[dict[str, Any]] = None,
445+
request_state: Optional[dict[str, Any]] = None,
443446
run_config: Optional[RunConfig] = None,
444447
) -> Generator[Event, None, None]:
445448
"""Runs the agent.
@@ -457,6 +460,9 @@ def run(
457460
user_id: The user ID of the session.
458461
session_id: The session ID of the session.
459462
new_message: A new message to append to the session.
463+
invocation_id: The invocation id to resume.
464+
state_delta: Optional state changes to apply to the session.
465+
request_state: Optional ephemeral state for the request.
460466
run_config: The run config for the agent.
461467
462468
Yields:
@@ -472,6 +478,9 @@ async def _invoke_run_async():
472478
user_id=user_id,
473479
session_id=session_id,
474480
new_message=new_message,
481+
invocation_id=invocation_id,
482+
state_delta=state_delta,
483+
request_state=request_state,
475484
run_config=run_config,
476485
)
477486
) as agen:
@@ -561,6 +570,10 @@ async def _run_with_trace(
561570
is_resumable = (
562571
self.resumability_config and self.resumability_config.is_resumable
563572
)
573+
# Three-branch decision tree:
574+
# A) invocation_id provided → resume that specific invocation
575+
# B) not resumable → must start a new invocation (requires new_message)
576+
# C) resumable, no explicit invocation_id → resolve or create new
564577
if invocation_id:
565578
if not is_resumable:
566579
raise ValueError(
@@ -614,8 +627,9 @@ async def _run_with_trace(
614627
if invocation_context.end_of_agents.get(
615628
invocation_context.agent.name
616629
):
617-
# Directly return if the current agent in invocation context is
618-
# already final.
630+
# Agent already completed in a prior invocation — skip execution
631+
# and return immediately. This can happen when resuming a
632+
# completed agent in a multi-agent pipeline.
619633
return
620634

621635
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:

src/google/adk/tools/mcp_tool/_internal.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
**Security Notes:**
2121
2222
- Header validation implements RFC 7230 §3.2 for proper HTTP header format
23-
- Only truly dangerous control characters are removed from header values
23+
- All ASCII control characters (0x00-0x1F) and DEL (0x7F) are removed from
24+
header values to prevent injection
2425
- All functions log security-relevant warnings when appropriate
2526
2627
**RFC 7230 Compliance:**
2728
2829
- Header names: only letters, digits, and hyphens allowed
29-
- Header values: control characters (0x00-0x1F, 0x7F) are dangerous
30+
- Header values: control characters including CRLF (0x00-0x1F, 0x7F) are
31+
removed to prevent injection
3032
3133
**Attack Prevention:**
3234
@@ -58,8 +60,11 @@
5860
"\x06",
5961
"\x07",
6062
"\x08",
63+
"\x09",
64+
"\x0a",
6165
"\x0b",
6266
"\x0c",
67+
"\x0d",
6368
"\x0e",
6469
"\x0f",
6570
"\x10",
@@ -133,10 +138,10 @@ def sanitize_header_value(value: Any) -> str:
133138
if not isinstance(value, str):
134139
value = str(value)
135140

136-
# Remove only characters that are truly dangerous for HTTP headers
137-
# These are control characters that can break parsing or enable injection
138-
# We DON'T remove all \r\n sequences as that would break legitimate multi-line headers
139-
# and violate RFC 7230 §3.2.4 which allows header folding
141+
# Remove CRLF and control characters to prevent header injection.
142+
# Header folding (obs-fold) was deprecated by RFC 7230 and obsoleted
143+
# by RFC 9110. CRLF in header values is the primary vector for
144+
# header injection and response splitting attacks.
140145
sanitized_chars = []
141146
for char in value:
142147
if char not in _DANGEROUS_CHARS:

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def provider(ctx: ReadonlyContext) -> Dict[str, str]:
163163

164164
validate_header_value(state_key, value, strict=strict)
165165
formatted_value = header_format.format(value=value)
166-
# Strip CRLF from the interpolated value to prevent header injection.
167-
# The format string is validated at construction time, but the runtime
168-
# value comes from session state and must never contain CRLF here.
166+
# Defense in depth: strip CRLF before sanitization. sanitize_header_value
167+
# also strips these, but we strip early to prevent injection via format
168+
# string interpolation.
169169
formatted_value = formatted_value.replace("\r", "").replace("\n", "")
170170
sanitized_value = sanitize_header_value(formatted_value)
171171

@@ -193,6 +193,15 @@ def combined_provider(ctx: ReadonlyContext) -> Dict[str, str]:
193193
try:
194194
provider_headers = provider(ctx)
195195
if provider_headers:
196+
overlapping = set(headers.keys()) & set(provider_headers.keys())
197+
if overlapping:
198+
logger.warning(
199+
"Duplicate header names %s from header provider "
200+
"%d/%d. Last value wins.",
201+
overlapping,
202+
i + 1,
203+
num_providers,
204+
)
196205
headers.update(provider_headers)
197206
except Exception as e:
198207
logger.error(f"Header provider {i+1}/{num_providers} failed: {e}")
@@ -301,7 +310,9 @@ def __init__(
301310
MCP server.
302311
sampling_capabilities: Optional capabilities for sampling.
303312
credential_key: A user specified key used to load and save this credential
304-
in a credential service. Used with auth_scheme.
313+
in a credential service. Used with auth_scheme. Note: when both
314+
credential_key and header_provider are configured, header_provider
315+
values take precedence over auth headers for the same header names.
305316
"""
306317

307318
# --- BEGIN BOUND TOKEN PATCH ---
@@ -433,6 +444,10 @@ def _get_auth_headers(
433444
# Default to using scheme name as header
434445
headers = {self._auth_config.auth_scheme.name: credential.api_key}
435446

447+
# Sanitize all header values to prevent injection attacks.
448+
if headers:
449+
headers = {k: sanitize_header_value(v) for k, v in headers.items()}
450+
436451
return headers
437452

438453
async def _execute_with_session(
@@ -444,17 +459,17 @@ async def _execute_with_session(
444459
"""Creates a session and executes a coroutine with it."""
445460
headers: Dict[str, str] = {}
446461

447-
# Add headers from header_provider if available
462+
# Add auth headers from exchanged credential first
463+
auth_headers = self._get_auth_headers(readonly_context)
464+
if auth_headers:
465+
headers.update(auth_headers)
466+
467+
# Add headers from header_provider (takes precedence over auth headers)
448468
if self._header_provider and readonly_context:
449469
provider_headers = self._header_provider(readonly_context)
450470
if provider_headers:
451471
headers.update(provider_headers)
452472

453-
# Add auth headers from exchanged credential if available
454-
auth_headers = self._get_auth_headers(readonly_context)
455-
if auth_headers:
456-
headers.update(auth_headers)
457-
458473
session = await self._mcp_session_manager.create_session(
459474
headers=headers if headers else None
460475
)
@@ -761,3 +776,33 @@ def _check_only_one_params_field(self):
761776
" set."
762777
)
763778
return self
779+
780+
@model_validator(mode="after")
781+
def _validate_state_header_config(self):
782+
"""Validates state_header_mapping and state_header_format consistency."""
783+
if not self.state_header_mapping:
784+
if self.state_header_format:
785+
raise ValueError(
786+
"state_header_format cannot be set without state_header_mapping."
787+
)
788+
return self
789+
790+
# Validate header names in state_header_mapping values
791+
for state_key, header_name in self.state_header_mapping.items():
792+
validate_header_name(header_name)
793+
794+
# Validate state_header_format keys match header names
795+
if self.state_header_format:
796+
header_names = set(self.state_header_mapping.values())
797+
for format_key in self.state_header_format:
798+
if format_key not in header_names:
799+
raise ValueError(
800+
f'state_header_format key "{format_key}" does not match'
801+
" any header name in state_header_mapping values."
802+
f" Expected one of: {sorted(header_names)}"
803+
)
804+
# Validate format strings don't contain CRLF
805+
for header_name, fmt in self.state_header_format.items():
806+
validate_header_format(fmt)
807+
808+
return self

src/google/adk/tools/mcp_tool/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
from typing import Callable
1818
from typing import Dict
19+
from typing import TYPE_CHECKING
1920

20-
from ...agents.readonly_context import ReadonlyContext
21+
if TYPE_CHECKING:
22+
from ...agents.readonly_context import ReadonlyContext
2123

22-
HeaderProvider = Callable[[ReadonlyContext], Dict[str, str]]
24+
HeaderProvider = Callable[["ReadonlyContext"], Dict[str, str]]

tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Tests for JWT token propagation feature in MCP toolset."""
1616

17+
import logging
1718
import sys
1819
from unittest.mock import Mock
1920

@@ -538,6 +539,10 @@ def test_header_value_sanitization_rfc_compliant(self):
538539
("nack\x15syn\x16etb\x17", "nacksynetb"), # NAK, SYN, ETB
539540
("can\x18em\x19sub\x1Aesc", "canemsubesc"), # CAN, EM, SUB, ESC
540541
("fs\x1ags\x1brs\x1cus", "fsgsrsus"), # FS, GS, RS, US
542+
("tok\r\nX-Injected: evil", "tokX-Injected: evil"), # CRLF
543+
("tok\nInjected: bad", "tokInjected: bad"), # LF
544+
("tok\rAnother: header", "tokAnother: header"), # CR
545+
("tab\ttest", "tabtest"), # TAB should be removed
541546
("space\x20test", "space test"), # Space should be preserved
542547
(
543548
"normal!@#$%^&*()test",
@@ -671,3 +676,102 @@ def test_header_format_crlf_injection_protection(self):
671676
header_name="Authorization",
672677
header_format="Bearer {value}\r\nX-Injected: evil",
673678
)
679+
680+
681+
class TestMcpToolsetConfigValidation:
682+
"""Test suite for McpToolsetConfig state header validation."""
683+
684+
def test_format_without_mapping_raises(self):
685+
"""Test that state_header_format without mapping raises ValueError."""
686+
with pytest.raises(ValueError, match="state_header_format cannot be set"):
687+
McpToolsetConfig(
688+
stdio_server_params=StdioServerParameters(
689+
command="test_command", args=[]
690+
),
691+
state_header_format={"Authorization": "Bearer {value}"},
692+
)
693+
694+
def test_format_key_not_in_mapping_values_raises(self):
695+
"""Test that format key not matching any mapping value raises."""
696+
with pytest.raises(ValueError, match="does not match"):
697+
McpToolsetConfig(
698+
stdio_server_params=StdioServerParameters(
699+
command="test_command", args=[]
700+
),
701+
state_header_mapping={"jwt_token": "Authorization"},
702+
state_header_format={"X-Wrong-Header": "Bearer {value}"},
703+
)
704+
705+
def test_invalid_header_name_in_mapping_raises(self):
706+
"""Test that invalid header name in mapping value raises ValueError."""
707+
with pytest.raises(ValueError, match="invalid characters"):
708+
McpToolsetConfig(
709+
stdio_server_params=StdioServerParameters(
710+
command="test_command", args=[]
711+
),
712+
state_header_mapping={"jwt_token": "Authorization\n"},
713+
)
714+
715+
def test_crlf_in_format_value_raises(self):
716+
"""Test that CRLF in format string raises ValueError."""
717+
with pytest.raises(ValueError, match="CRLF"):
718+
McpToolsetConfig(
719+
stdio_server_params=StdioServerParameters(
720+
command="test_command", args=[]
721+
),
722+
state_header_mapping={"jwt_token": "Authorization"},
723+
state_header_format={
724+
"Authorization": "Bearer {value}\r\nX-Injected: evil"
725+
},
726+
)
727+
728+
def test_valid_config_passes_validation(self):
729+
"""Test that valid config passes all validation."""
730+
config = McpToolsetConfig(
731+
stdio_server_params=StdioServerParameters(
732+
command="test_command", args=[]
733+
),
734+
state_header_mapping={
735+
"jwt_token": "Authorization",
736+
"tenant_id": "X-Tenant-ID",
737+
},
738+
state_header_format={"Authorization": "Bearer {value}"},
739+
)
740+
assert config.state_header_mapping is not None
741+
742+
743+
class TestCombinedHeaderProviderDuplicateWarning:
744+
"""Test suite for duplicate header warning in combined provider."""
745+
746+
def test_warns_on_duplicate_headers(self, caplog):
747+
"""Test that duplicate header names trigger a warning."""
748+
from google.adk.tools.mcp_tool.mcp_toolset import create_combined_header_provider
749+
750+
provider1 = lambda ctx: {"Authorization": "Bearer token1"} # noqa: E731
751+
provider2 = lambda ctx: {"Authorization": "Bearer token2"} # noqa: E731
752+
753+
combined = create_combined_header_provider([provider1, provider2])
754+
755+
with caplog.at_level(logging.WARNING, logger="google_adk"):
756+
headers = combined(Mock(spec=ReadonlyContext))
757+
758+
assert headers["Authorization"] == "Bearer token2"
759+
assert "Duplicate header names" in caplog.text
760+
761+
def test_no_warning_without_duplicates(self, caplog):
762+
"""Test that no warning is logged when headers don't overlap."""
763+
from google.adk.tools.mcp_tool.mcp_toolset import create_combined_header_provider
764+
765+
provider1 = lambda ctx: {"Authorization": "Bearer token1"} # noqa: E731
766+
provider2 = lambda ctx: {"X-Tenant-ID": "tenant-123"} # noqa: E731
767+
768+
combined = create_combined_header_provider([provider1, provider2])
769+
770+
with caplog.at_level(logging.WARNING, logger="google_adk"):
771+
headers = combined(Mock(spec=ReadonlyContext))
772+
773+
assert "Duplicate header names" not in caplog.text
774+
assert headers == {
775+
"Authorization": "Bearer token1",
776+
"X-Tenant-ID": "tenant-123",
777+
}

0 commit comments

Comments
 (0)