Skip to content

Commit 13375b4

Browse files
committed
fix(mcp): Remove duplicate create_session_state_header_provider
- Remove duplicate function from _internal.py (kept in mcp_toolset.py)
- Inline _sanitize_header_value into sanitize_header_value
- Update test imports to use public API

Addresses review comment
1 parent 54023d6 commit 13375b4

2 files changed

Lines changed: 27 additions & 97 deletions

File tree

src/google/adk/tools/mcp_tool/_internal.py

Lines changed: 15 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -142,38 +142,6 @@ def validate_header_name(header_name: str) -> None:
142142
)
143143

144144

145-
def _sanitize_header_value(value: str) -> str:
146-
"""Sanitizes a header value to prevent injection attacks.
147-
148-
This function removes ONLY truly dangerous characters that could cause
149-
header injection attacks, while remaining RFC 7230 compliant.
150-
151-
Args:
152-
value: The header value to sanitize.
153-
154-
Returns:
155-
The sanitized header value with dangerous characters removed.
156-
"""
157-
if not isinstance(value, str):
158-
value = str(value)
159-
160-
# Remove only characters that are truly dangerous for HTTP headers
161-
# These are control characters that can break parsing or enable injection
162-
# We DON'T remove all \r\n sequences as that would break legitimate multi-line headers
163-
# and violate RFC 7230 §3.2.4 which allows header folding
164-
sanitized_chars = []
165-
for char in value:
166-
if char not in _DANGEROUS_CHARS:
167-
sanitized_chars.append(char)
168-
else:
169-
logger.warning(
170-
f"Removed dangerous character {repr(char)} from header value "
171-
"for security reasons"
172-
)
173-
174-
return "".join(sanitized_chars)
175-
176-
177145
def _validate_header_value(value: Any, allow_binary: bool = False) -> None:
178146
"""Validates header values with RFC 7230 compliance and proper binary handling.
179147
@@ -236,7 +204,21 @@ def sanitize_header_value(value: Any) -> str:
236204
if not isinstance(value, str):
237205
value = str(value)
238206

239-
return _sanitize_header_value(value)
207+
# Remove only characters that are truly dangerous for HTTP headers
208+
# These are control characters that can break parsing or enable injection
209+
# We DON'T remove all \r\n sequences as that would break legitimate multi-line headers
210+
# and violate RFC 7230 §3.2.4 which allows header folding
211+
sanitized_chars = []
212+
for char in value:
213+
if char not in _DANGEROUS_CHARS:
214+
sanitized_chars.append(char)
215+
else:
216+
logger.warning(
217+
f"Removed dangerous character {repr(char)} from header value "
218+
"for security reasons"
219+
)
220+
221+
return "".join(sanitized_chars)
240222

241223

242224
def validate_header_value(
@@ -263,55 +245,3 @@ def validate_header_value(
263245
raise ValueError(msg)
264246
else:
265247
logger.warning(msg)
266-
267-
268-
def create_session_state_header_provider(
269-
state_key: str,
270-
header_name: str = "Authorization",
271-
header_format: str = "Bearer {value}",
272-
default_value: str = None,
273-
strict: bool = False,
274-
):
275-
"""Creates a header provider that extracts values from session state.
276-
277-
This utility function generates a header_provider callable that can be used
278-
with McpToolset to automatically extract values from session state and
279-
format them as HTTP headers for MCP server connections.
280-
281-
.. warning::
282-
**Security Best Practice**: For sensitive, short-lived tokens like JWTs,
283-
use ``request_state`` instead of ``session.state`` to avoid persisting
284-
sensitive data to the database. Pass tokens via
285-
``RunAgentRequest.request_state``, which will override ``session.state``
286-
for the duration of the request without being persisted.
287-
288-
Args:
289-
state_key: The key to look up in session.state (or request_state).
290-
header_name: The HTTP header name to set (default: 'Authorization').
291-
header_format: Format string for the header value. Use {value} as a
292-
placeholder for the state value (default: 'Bearer {value}').
293-
default_value: Default value if state_key is not found in session state.
294-
If None, the header is omitted when the key is missing.
295-
strict: If True, raises ValueError when non-primitive types are
296-
encountered. If False (default), logs a warning instead.
297-
298-
Returns:
299-
A callable that takes a ReadonlyContext and returns a dictionary of
300-
headers to be used for the MCP session.
301-
"""
302-
# Validate header name upfront
303-
validate_header_name(header_name)
304-
305-
def provider(ctx) -> dict[str, str]:
306-
value = ctx.state.get(state_key, default_value)
307-
# Skip header if value is None or empty string
308-
if value is None or value == "":
309-
return {}
310-
311-
validate_header_value(state_key, value, strict=strict)
312-
formatted_value = header_format.format(value=value)
313-
sanitized_value = sanitize_header_value(formatted_value)
314-
315-
return {header_name: sanitized_value}
316-
317-
return provider

tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,7 @@ def test_header_name_validation_invalid_names(self):
271271

272272
def test_header_value_sanitization_safe_values(self):
273273
"""Test that safe header values are unchanged."""
274-
from google.adk.tools.mcp_tool._internal import _sanitize_header_value
274+
from google.adk.tools.mcp_tool._internal import sanitize_header_value
275275

276276
safe_values = [
277277
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
@@ -281,12 +281,12 @@ def test_header_value_sanitization_safe_values(self):
281281
]
282282

283283
for value in safe_values:
284-
result = _sanitize_header_value(value)
284+
result = sanitize_header_value(value)
285285
assert result == value
286286

287287
def test_header_value_sanitization_dangerous_values(self):
288288
"""Test that dangerous characters are removed from header values."""
289-
from google.adk.tools.mcp_tool._internal import _sanitize_header_value
289+
from google.adk.tools.mcp_tool._internal import sanitize_header_value
290290

291291
dangerous_values = [
292292
("Bearer token\x00injected", "Bearer tokeninjected"),
@@ -296,17 +296,17 @@ def test_header_value_sanitization_dangerous_values(self):
296296
]
297297

298298
for input_val, expected in dangerous_values:
299-
result = _sanitize_header_value(input_val)
299+
result = sanitize_header_value(input_val)
300300
assert result == expected
301301

302302
def test_header_value_sanitization_non_string_values(self):
303303
"""Test that non-string values are converted to string."""
304-
from google.adk.tools.mcp_tool._internal import _sanitize_header_value
304+
from google.adk.tools.mcp_tool._internal import sanitize_header_value
305305

306-
result_int = _sanitize_header_value(123)
306+
result_int = sanitize_header_value(123)
307307
assert result_int == "123"
308308

309-
result_bool = _sanitize_header_value(True)
309+
result_bool = sanitize_header_value(True)
310310
assert result_bool == "True"
311311

312312
def test_session_state_header_provider_with_invalid_header_name(self):
@@ -505,7 +505,7 @@ def test_header_name_validation_rfc_compliant(self):
505505

506506
def test_header_value_sanitization_rfc_compliant(self):
507507
"""Test that header value sanitization is RFC 7230 compliant."""
508-
from google.adk.tools.mcp_tool._internal import _sanitize_header_value
508+
from google.adk.tools.mcp_tool._internal import sanitize_header_value
509509

510510
# Safe header values should remain unchanged
511511
safe_values = [
@@ -519,7 +519,7 @@ def test_header_value_sanitization_rfc_compliant(self):
519519
]
520520

521521
for value in safe_values:
522-
result = _sanitize_header_value(value)
522+
result = sanitize_header_value(value)
523523
assert result == value
524524

525525
# Only truly dangerous characters should be removed
@@ -546,12 +546,12 @@ def test_header_value_sanitization_rfc_compliant(self):
546546
]
547547

548548
for input_val, expected in dangerous_cases:
549-
result = _sanitize_header_value(input_val)
549+
result = sanitize_header_value(input_val)
550550
assert result == expected
551551

552552
def test_header_value_preserves_rfc_folding(self):
553553
"""Test that legitimate CRLF sequences for header folding are preserved."""
554-
from google.adk.tools.mcp_tool._internal import _sanitize_header_value
554+
from google.adk.tools.mcp_tool._internal import sanitize_header_value
555555

556556
# Multi-line headers with proper folding should be preserved (RFC 7230 §3.2.4)
557557
folding_cases = [
@@ -565,7 +565,7 @@ def test_header_value_preserves_rfc_folding(self):
565565
]
566566

567567
for folding_case in folding_cases:
568-
result = _sanitize_header_value(folding_case[0])
568+
result = sanitize_header_value(folding_case[0])
569569
assert result == folding_case[1]
570570

571571
def test_header_value_validation_rfc_compliant(self):

0 commit comments

Comments
 (0)