Skip to content

Commit 9db0910

Browse files
committed
fix(fetch): add DNS rebinding TOCTOU protection via SSRFSafeTransport
Address review feedback on SSRF protection: - Add SSRFSafeTransport custom async transport that resolves DNS, validates the resolved IP, and replaces the hostname with the validated IP before connecting. This eliminates the TOCTOU window between validate_url_for_ssrf() and the actual HTTP request. - Integrate SSRFSafeTransport into fetch_url() and check_may_autonomously_fetch_url() replacing direct AsyncClient usage. - Add 6 DNS rebinding tests including full attack scenario simulation. - Update existing tests to match new transport-based architecture.
1 parent 4140538 commit 9db0910

3 files changed

Lines changed: 245 additions & 17 deletions

File tree

src/fetch/src/mcp_server_fetch/server.py

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import os
33
import socket
44
import ssl
5+
from types import TracebackType
56
from typing import Annotated, Tuple
67
from urllib.parse import urlparse, urlunparse
78

9+
import httpx
810
import markdownify
911
import readabilipy.simple_json
1012
from mcp.shared.exceptions import McpError
@@ -237,9 +239,9 @@ def validate_url_for_ssrf(url: str) -> None:
237239
McpError: If the URL is potentially dangerous
238240
239241
Security Note:
240-
This validation happens BEFORE the request is made, but DNS rebinding
241-
attacks could still occur. For maximum security, use network-level
242-
controls (firewall rules, egress filtering).
242+
This validation provides early rejection of obviously dangerous URLs.
243+
DNS rebinding protection is handled at the transport layer by
244+
SSRFSafeTransport, which validates resolved IPs at connection time.
243245
"""
244246
try:
245247
parsed = urlparse(url)
@@ -332,6 +334,94 @@ def validate_url_for_ssrf(url: str) -> None:
332334
))
333335

334336

337+
class SSRFSafeTransport(httpx.AsyncBaseTransport):
338+
"""
339+
Custom async transport that prevents DNS rebinding attacks.
340+
341+
DNS rebinding TOCTOU (Time-of-Check-Time-of-Use) attack:
342+
1. validate_url_for_ssrf() resolves DNS → gets public IP → passes check
343+
2. Attacker's DNS server changes the record to a private IP (e.g., 169.254.169.254)
344+
3. httpx resolves DNS again → gets private IP → connects to internal service
345+
346+
This transport eliminates the TOCTOU window by:
347+
1. Resolving DNS ourselves
348+
2. Validating the resolved IP
349+
3. Replacing the hostname in the URL with the validated IP
350+
4. Preserving the original Host header for correct HTTP routing
351+
"""
352+
353+
def __init__(self, proxy: str | None = None, verify: bool = True):
354+
kwargs: dict = {"verify": verify}
355+
if proxy:
356+
kwargs["proxy"] = proxy
357+
self._transport = httpx.AsyncHTTPTransport(**kwargs)
358+
359+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
360+
hostname = request.url.host
361+
# Skip IP validation for already-resolved IPs
362+
try:
363+
ipaddress.ip_address(hostname)
364+
# Already an IP - validation was done in validate_url_for_ssrf()
365+
return await self._transport.handle_async_request(request)
366+
except ValueError:
367+
pass # It's a hostname, resolve it
368+
369+
# Resolve DNS
370+
try:
371+
addr_info = socket.getaddrinfo(
372+
hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
373+
)
374+
if not addr_info:
375+
raise McpError(ErrorData(
376+
code=INVALID_PARAMS,
377+
message=f"Failed to resolve hostname '{hostname}': no addresses found",
378+
))
379+
resolved_ip = addr_info[0][4][0]
380+
except socket.gaierror as e:
381+
raise McpError(ErrorData(
382+
code=INVALID_PARAMS,
383+
message=f"Failed to resolve hostname '{hostname}': {str(e)}",
384+
))
385+
386+
# Validate resolved IP against SSRF rules
387+
if not ALLOW_PRIVATE_IPS and _is_ip_private_or_reserved(resolved_ip):
388+
raise McpError(ErrorData(
389+
code=INVALID_PARAMS,
390+
message=f"DNS rebinding protection: hostname '{hostname}' resolved to "
391+
f"private/internal IP '{resolved_ip}' at connection time. "
392+
f"Set MCP_FETCH_ALLOW_PRIVATE_IPS=true to allow internal network access.",
393+
))
394+
395+
# Replace hostname with validated IP to prevent DNS rebinding
396+
# The Host header is already set to the original hostname by httpx
397+
new_url = request.url.copy_with(host=resolved_ip)
398+
# Create new request with the IP-based URL but same headers (including Host)
399+
new_request = httpx.Request(
400+
method=request.method,
401+
url=new_url,
402+
headers=request.headers,
403+
stream=request.stream,
404+
extensions=request.extensions,
405+
)
406+
407+
return await self._transport.handle_async_request(new_request)
408+
409+
async def aclose(self):
410+
await self._transport.aclose()
411+
412+
async def __aenter__(self):
413+
await self._transport.__aenter__()
414+
return self
415+
416+
async def __aexit__(
417+
self,
418+
exc_type: type[BaseException] | None = None,
419+
exc_val: BaseException | None = None,
420+
exc_tb: TracebackType | None = None,
421+
) -> None:
422+
await self._transport.__aexit__(exc_type, exc_val, exc_tb)
423+
424+
335425
def extract_content_from_html(html: str) -> str:
336426
"""Extract and convert HTML content to Markdown format.
337427
@@ -381,14 +471,13 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url:
381471
- SSL certificate verification (configurable via SSL_VERIFY)
382472
- Comprehensive SSL error handling
383473
"""
384-
import httpx
385-
386474
robot_txt_url = get_robots_txt_url(url)
387475

388476
# SSRF Protection: Validate robots.txt URL before fetching
389477
validate_url_for_ssrf(robot_txt_url)
390478

391-
async with httpx.AsyncClient(proxy=proxy_url, verify=SSL_VERIFY) as client:
479+
transport = SSRFSafeTransport(proxy=proxy_url, verify=SSL_VERIFY)
480+
async with httpx.AsyncClient(transport=transport) as client:
392481
try:
393482
response = await client.get(
394483
robot_txt_url,
@@ -461,12 +550,11 @@ async def fetch_url(
461550
- User-Agent header for transparency
462551
- Comprehensive SSL error handling (catches wrapped exceptions)
463552
"""
464-
import httpx
465-
466553
# SSRF Protection: Validate URL before fetching
467554
validate_url_for_ssrf(url)
468555

469-
async with httpx.AsyncClient(proxy=proxy_url, verify=SSL_VERIFY) as client:
556+
transport = SSRFSafeTransport(proxy=proxy_url, verify=SSL_VERIFY)
557+
async with httpx.AsyncClient(transport=transport) as client:
470558
try:
471559
response = await client.get(
472560
url,

src/fetch/tests/test_security.py

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_parse_obfuscated_ip,
2626
fetch_url,
2727
extract_content_from_html,
28+
SSRFSafeTransport,
2829
BLOCKED_HOSTNAMES,
2930
CLOUD_METADATA_IPS,
3031
)
@@ -134,8 +135,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
134135
import mcp_server_fetch.server as server_module
135136
importlib.reload(server_module)
136137

137-
# Mock httpx.AsyncClient to verify verify=False is passed
138-
with patch('httpx.AsyncClient') as mock_client:
138+
# Mock httpx.AsyncClient and AsyncHTTPTransport to verify verify=False is passed
139+
with patch('httpx.AsyncClient') as mock_client, \
140+
patch('httpx.AsyncHTTPTransport') as mock_transport_class:
139141
mock_response = MagicMock()
140142
mock_response.status_code = 200
141143
mock_response.text = "<html><body>Test</body></html>"
@@ -155,9 +157,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
155157
"TestAgent/1.0"
156158
)
157159

158-
# Verify AsyncClient was called with verify=False
159-
mock_client.assert_called_once()
160-
call_kwargs = mock_client.call_args[1]
160+
# Verify AsyncHTTPTransport was created with verify=False
161+
mock_transport_class.assert_called_once()
162+
call_kwargs = mock_transport_class.call_args[1]
161163
assert call_kwargs.get('verify') is False
162164

163165

@@ -646,6 +648,141 @@ def test_url_with_port_bypass_attempt(self):
646648
validate_url_for_ssrf("http://127.0.0.1:65535/")
647649

648650

651+
# =============================================================================
652+
# 8. DNS REBINDING PROTECTION TESTS
653+
# =============================================================================
654+
655+
class TestDNSRebindingProtection:
656+
"""Test suite for DNS rebinding TOCTOU protection via SSRFSafeTransport."""
657+
658+
@pytest.mark.asyncio
659+
async def test_transport_blocks_private_ip_at_connection_time(self):
660+
"""SSRFSafeTransport must block requests when DNS resolves to private IP."""
661+
import httpx
662+
663+
transport = SSRFSafeTransport(verify=False)
664+
665+
# Simulate DNS resolving to a private IP (127.0.0.1)
666+
with patch("socket.getaddrinfo") as mock_dns:
667+
mock_dns.return_value = [
668+
(2, 1, 6, '', ('127.0.0.1', 0)),
669+
]
670+
request = httpx.Request("GET", "http://evil-rebind.example.com/secret")
671+
672+
with pytest.raises(McpError, match="DNS rebinding protection"):
673+
await transport.handle_async_request(request)
674+
675+
@pytest.mark.asyncio
676+
async def test_transport_blocks_metadata_ip_at_connection_time(self):
677+
"""SSRFSafeTransport must block DNS rebinding to cloud metadata IP."""
678+
import httpx
679+
680+
transport = SSRFSafeTransport(verify=False)
681+
682+
# Simulate DNS rebinding: attacker DNS returns metadata IP
683+
with patch("socket.getaddrinfo") as mock_dns:
684+
mock_dns.return_value = [
685+
(2, 1, 6, '', ('169.254.169.254', 0)),
686+
]
687+
request = httpx.Request("GET", "http://evil-rebind.example.com/metadata")
688+
689+
with pytest.raises(McpError, match="DNS rebinding protection"):
690+
await transport.handle_async_request(request)
691+
692+
@pytest.mark.asyncio
693+
async def test_transport_allows_public_ip(self):
694+
"""SSRFSafeTransport must allow requests when DNS resolves to public IP."""
695+
import httpx
696+
697+
transport = SSRFSafeTransport(verify=False)
698+
699+
# Simulate DNS resolving to a public IP
700+
with patch("socket.getaddrinfo") as mock_dns, \
701+
patch.object(transport, '_transport') as mock_inner:
702+
mock_dns.return_value = [
703+
(2, 1, 6, '', ('93.184.216.34', 0)),
704+
]
705+
mock_response = httpx.Response(200, text="OK")
706+
mock_inner.handle_async_request = AsyncMock(return_value=mock_response)
707+
708+
request = httpx.Request("GET", "http://example.com/page")
709+
response = await transport.handle_async_request(request)
710+
711+
assert response.status_code == 200
712+
# Verify the inner transport was called with the IP-based URL
713+
called_request = mock_inner.handle_async_request.call_args[0][0]
714+
assert called_request.url.host == "93.184.216.34"
715+
# Verify Host header preserved
716+
assert called_request.headers["host"] == "example.com"
717+
718+
@pytest.mark.asyncio
719+
async def test_transport_skips_validation_for_direct_ip(self):
720+
"""SSRFSafeTransport should skip DNS resolution for direct IP URLs."""
721+
import httpx
722+
723+
transport = SSRFSafeTransport(verify=False)
724+
725+
# Direct IP URL - should go straight to inner transport (IP already validated by validate_url_for_ssrf)
726+
with patch.object(transport, '_transport') as mock_inner, \
727+
patch("socket.getaddrinfo") as mock_dns:
728+
mock_response = httpx.Response(200, text="OK")
729+
mock_inner.handle_async_request = AsyncMock(return_value=mock_response)
730+
731+
request = httpx.Request("GET", "http://93.184.216.34/page")
732+
await transport.handle_async_request(request)
733+
734+
# DNS should NOT be called for direct IP
735+
mock_dns.assert_not_called()
736+
mock_inner.handle_async_request.assert_called_once()
737+
738+
@pytest.mark.asyncio
739+
async def test_transport_blocks_dns_failure(self):
740+
"""SSRFSafeTransport must raise error when DNS resolution fails."""
741+
import httpx
742+
import socket as socket_module
743+
744+
transport = SSRFSafeTransport(verify=False)
745+
746+
with patch("socket.getaddrinfo") as mock_dns:
747+
mock_dns.side_effect = socket_module.gaierror("Name resolution failed")
748+
request = httpx.Request("GET", "http://nonexistent.example.com/")
749+
750+
with pytest.raises(McpError, match="Failed to resolve"):
751+
await transport.handle_async_request(request)
752+
753+
@pytest.mark.asyncio
754+
async def test_dns_rebinding_scenario(self):
755+
"""
756+
Full DNS rebinding attack scenario:
757+
1. validate_url_for_ssrf() sees public IP (passes)
758+
2. SSRFSafeTransport resolves DNS again and sees private IP (blocks)
759+
"""
760+
import httpx
761+
762+
call_count = 0
763+
764+
def rebinding_dns(hostname, *args, **kwargs):
765+
nonlocal call_count
766+
call_count += 1
767+
if call_count == 1:
768+
# First call (validate_url_for_ssrf): return public IP
769+
return [(2, 1, 6, '', ('93.184.216.34', 0))]
770+
else:
771+
# Second call (SSRFSafeTransport): return private IP (rebinding!)
772+
return [(2, 1, 6, '', ('169.254.169.254', 0))]
773+
774+
with patch("socket.getaddrinfo", side_effect=rebinding_dns):
775+
# First validation passes (public IP)
776+
validate_url_for_ssrf("http://evil-rebind.example.com/")
777+
778+
# But transport-level check catches the rebinding
779+
transport = SSRFSafeTransport(verify=False)
780+
request = httpx.Request("GET", "http://evil-rebind.example.com/metadata")
781+
782+
with pytest.raises(McpError, match="DNS rebinding protection"):
783+
await transport.handle_async_request(request)
784+
785+
649786
# =============================================================================
650787
# RUN CONFIGURATION
651788
# =============================================================================

src/fetch/tests/test_server.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,14 @@ async def test_fetch_500_raises_error(self):
305305

306306
@pytest.mark.asyncio
307307
async def test_fetch_with_proxy(self):
308-
"""Test that proxy URL is passed to client."""
308+
"""Test that proxy URL is passed to SSRFSafeTransport."""
309309
mock_response = MagicMock()
310310
mock_response.status_code = 200
311311
mock_response.text = '{"data": "test"}'
312312
mock_response.headers = {"content-type": "application/json"}
313313

314314
with patch("httpx.AsyncClient") as mock_client_class, \
315+
patch("httpx.AsyncHTTPTransport") as mock_transport_class, \
315316
patch("mcp_server_fetch.server.validate_url_for_ssrf"), \
316317
patch("mcp_server_fetch.server.SSL_VERIFY", True):
317318
mock_client = AsyncMock()
@@ -325,5 +326,7 @@ async def test_fetch_with_proxy(self):
325326
proxy_url="http://proxy.example.com:8080"
326327
)
327328

328-
# Verify AsyncClient was called with proxy and verify
329-
mock_client_class.assert_called_once_with(proxy="http://proxy.example.com:8080", verify=True)
329+
# Verify AsyncHTTPTransport was created with proxy and verify
330+
mock_transport_class.assert_called_once_with(
331+
verify=True, proxy="http://proxy.example.com:8080"
332+
)

0 commit comments

Comments
 (0)