Skip to content

Commit e7f43fb

Browse files
committed
fix(mcp): consolidate header sanitization and clean up review findings
- Move sanitization from _get_auth_headers to _execute_with_session boundary - Replace verbose _DANGEROUS_CHARS literal with frozenset comprehension - Remove redundant CRLF strip already handled by sanitize_header_value - Use lazy %-formatting instead of f-strings in logging - Clarify docstrings for run() Optional new_message and from_config()
1 parent b2bc997 commit e7f43fb

3 files changed

Lines changed: 20 additions & 49 deletions

File tree

src/google/adk/runners.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,9 @@ def run(
458458
Args:
459459
user_id: The user ID of the session.
460460
session_id: The session ID of the session.
461-
new_message: A new message to append to the session.
461+
new_message: An optional new message to append to the session. When
462+
omitted, either an invocation_id must be provided to resume a
463+
previous invocation, or the app must be configured as resumable.
462464
invocation_id: The invocation id to resume.
463465
state_delta: Optional state changes to apply to the session.
464466
request_state: Optional ephemeral state for the request.

src/google/adk/tools/mcp_tool/_internal.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,43 +48,9 @@
4848
# RFC 7230 compliant header name pattern (allows letters, digits, hyphens)
4949
_HEADER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9-]+\Z")
5050

51-
# Truly dangerous characters that should never appear in header values
52-
# These are characters that can break HTTP parsing or cause injection
53-
_DANGEROUS_CHARS = {
54-
"\x00",
55-
"\x01",
56-
"\x02",
57-
"\x03",
58-
"\x04",
59-
"\x05",
60-
"\x06",
61-
"\x07",
62-
"\x08",
63-
"\x09",
64-
"\x0a",
65-
"\x0b",
66-
"\x0c",
67-
"\x0d",
68-
"\x0e",
69-
"\x0f",
70-
"\x10",
71-
"\x11",
72-
"\x12",
73-
"\x13",
74-
"\x14",
75-
"\x15",
76-
"\x16",
77-
"\x17",
78-
"\x18",
79-
"\x19",
80-
"\x1a",
81-
"\x1b",
82-
"\x1c",
83-
"\x1d",
84-
"\x1e",
85-
"\x1f",
86-
"\x7f",
87-
}
51+
# ASCII control characters (0x00-0x1F) and DEL (0x7F) that must never
52+
# appear in HTTP header values — prevents injection and response splitting.
53+
_DANGEROUS_CHARS = frozenset(chr(i) for i in range(0x20)) | frozenset({chr(0x7F)})
8854

8955

9056
def validate_header_name(header_name: str) -> None:

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,6 @@ def provider(ctx: ReadonlyContext) -> Dict[str, str]:
163163

164164
validate_header_value(state_key, value, strict=strict)
165165
formatted_value = header_format.format(value=value)
166-
# Defense in depth: strip CRLF before sanitization. sanitize_header_value
167-
# also strips these, but we strip early to prevent injection via format
168-
# string interpolation.
169-
formatted_value = formatted_value.replace("\r", "").replace("\n", "")
170166
sanitized_value = sanitize_header_value(formatted_value)
171167

172168
return {header_name: sanitized_value}
@@ -204,12 +200,13 @@ def combined_provider(ctx: ReadonlyContext) -> Dict[str, str]:
204200
)
205201
headers.update(provider_headers)
206202
except Exception as e:
207-
logger.error(f"Header provider {i+1}/{num_providers} failed: {e}")
203+
logger.error("Header provider %d/%d failed: %s", i + 1, num_providers, e)
208204
raise
209205

210206
if headers:
211207
logger.debug(
212-
f"Combined header provider generated {len(headers)} total headers"
208+
"Combined header provider generated %d total headers",
209+
len(headers),
213210
)
214211
return headers
215212

@@ -444,10 +441,6 @@ def _get_auth_headers(
444441
# Default to using scheme name as header
445442
headers = {self._auth_config.auth_scheme.name: credential.api_key}
446443

447-
# Sanitize all header values to prevent injection attacks.
448-
if headers:
449-
headers = {k: sanitize_header_value(v) for k, v in headers.items()}
450-
451444
return headers
452445

453446
async def _execute_with_session(
@@ -470,6 +463,10 @@ async def _execute_with_session(
470463
if provider_headers:
471464
headers.update(provider_headers)
472465

466+
# Sanitize all header values at the boundary to prevent injection.
467+
if headers:
468+
headers = {k: sanitize_header_value(v) for k, v in headers.items()}
469+
473470
session = await self._mcp_session_manager.create_session(
474471
headers=headers if headers else None
475472
)
@@ -611,7 +608,13 @@ def get_auth_config(self) -> Optional[AuthConfig]:
611608
def from_config(
612609
cls: type[McpToolset], config: ToolArgsConfig, config_abs_path: str
613610
) -> McpToolset:
614-
"""Creates an McpToolset from a configuration object."""
611+
"""Creates an McpToolset from a configuration object.
612+
613+
Note: This method constructs the header_provider from the declarative
614+
state_header_mapping in the config. Since McpToolsetConfig is a
615+
serializable Pydantic model, it cannot hold callable objects. To use a
616+
custom header_provider, construct McpToolset directly.
617+
"""
615618
mcp_toolset_config = McpToolsetConfig.model_validate(config.model_dump())
616619

617620
if mcp_toolset_config.stdio_server_params:

0 commit comments

Comments
 (0)