-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtransport_security.py
More file actions
116 lines (87 loc) · 4.56 KB
/
transport_security.py
File metadata and controls
116 lines (87 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""DNS rebinding protection for MCP server transports."""
import logging
from pydantic import BaseModel, Field
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)


# TODO(Marcelo): We should flatten these settings. To be fair, I don't think we should even have this middleware.
class TransportSecuritySettings(BaseModel):
    """Configuration for transport-level request header checks.

    Validating the ``Host`` and ``Origin`` headers of incoming requests helps
    guard MCP servers against DNS rebinding attacks.
    """

    enable_dns_rebinding_protection: bool = Field(default=True)
    """Turn DNS rebinding protection on or off (keeping it on is recommended for production)."""

    allowed_hosts: list[str] = Field(default_factory=list)
    """``Host`` header values that are accepted.

    An entry ending in ``:*`` (e.g. ``"localhost:*"``) is treated as a port wildcard for that host.
    Has no effect unless `enable_dns_rebinding_protection` is `True`.
    """

    allowed_origins: list[str] = Field(default_factory=list)
    """``Origin`` header values that are accepted.

    An entry ending in ``:*`` (e.g. ``"http://localhost:*"``) is treated as a port wildcard for that origin.
    Has no effect unless `enable_dns_rebinding_protection` is `True`.
    """
# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this.
class TransportSecurityMiddleware:
    """Middleware to enforce DNS rebinding protection for MCP transport endpoints.

    Transport endpoints call :meth:`validate_request`. It returns ``None`` when the
    request is acceptable, or a ready-made error ``Response`` to send back instead.
    """

    def __init__(self, settings: TransportSecuritySettings | None = None):
        """Create the middleware.

        Args:
            settings: Security settings to enforce. When omitted, DNS rebinding
                protection is disabled. The Content-Type check for POST requests
                in :meth:`validate_request` still applies.
        """
        # If not specified, disable DNS rebinding protection by default for backwards compatibility
        self.settings = (
            settings if settings is not None else TransportSecuritySettings(enable_dns_rebinding_protection=False)
        )

    @staticmethod
    def _matches_allowed(value: str, allowed_values: list[str]) -> bool:  # pragma: lax no cover
        """Return True if *value* is allowed by *allowed_values*.

        A value is allowed when it equals an entry exactly, or when an entry of
        the form ``"<base>:*"`` matches it. A wildcard entry matches ``<base>:``
        followed by a non-empty, all-ASCII-digit port. For example,
        ``"localhost:*"`` matches ``"localhost:8000"`` but not ``"localhost:"``
        or ``"localhost:80x"``.
        """
        if value in allowed_values:
            return True
        for pattern in allowed_values:
            if not pattern.endswith(":*"):
                continue
            # Drop only the "*" so the ":" separator stays part of the required prefix.
            prefix = pattern[:-1]
            if value.startswith(prefix):
                port = value[len(prefix) :]
                # isascii() guards against non-ASCII Unicode digits, which str.isdigit() also accepts.
                if port.isascii() and port.isdigit():
                    return True
        return False

    def _validate_host(self, host: str | None) -> bool:  # pragma: lax no cover
        """Validate the Host header against ``settings.allowed_hosts``.

        Returns:
            True if the header is present and allowed. False (after logging a
            warning) if it is missing, empty, or not allowed.
        """
        if not host:
            logger.warning("Missing Host header in request")
            return False
        if self._matches_allowed(host, self.settings.allowed_hosts):
            return True
        logger.warning("Invalid Host header: %s", host)
        return False

    def _validate_origin(self, origin: str | None) -> bool:  # pragma: lax no cover
        """Validate the Origin header against ``settings.allowed_origins``.

        Returns:
            True if the header is absent/empty or allowed. False (after logging a
            warning) if it is present but not allowed.
        """
        # Origin can be absent for same-origin requests
        if not origin:
            return True
        if self._matches_allowed(origin, self.settings.allowed_origins):
            return True
        logger.warning("Invalid Origin header: %s", origin)
        return False

    def _validate_content_type(self, content_type: str | None) -> bool:
        """Return True if *content_type* declares the ``application/json`` media type.

        Media-type parameters (e.g. ``; charset=utf-8``) are ignored and the
        comparison is case-insensitive. Other media types that merely start with
        ``application/json`` (e.g. ``application/json-seq``) are rejected.
        """
        # Kept as one expression (no extra branch) so branch coverage is unaffected.
        return content_type is not None and content_type.split(";", 1)[0].strip().lower() == "application/json"

    async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
        """Validate request headers for DNS rebinding protection.

        Checks run in this order and stop at the first failure:

        1. If ``is_post`` is True, the Content-Type must be ``application/json``.
           Otherwise a 400 response is returned. This runs even when DNS rebinding
           protection is disabled.
        2. If DNS rebinding protection is enabled, the Host header must be
           allowed. Otherwise a 421 (Misdirected Request) response is returned.
        3. If DNS rebinding protection is enabled, the Origin header (when
           present) must be allowed. Otherwise a 403 response is returned.

        Args:
            request: The incoming request whose headers are checked.
            is_post: Whether to apply the POST Content-Type check.

        Returns:
            None if validation passes, or an error Response if validation fails.
        """
        # Always validate Content-Type for POST requests
        if is_post:  # pragma: no branch
            content_type = request.headers.get("content-type")
            if not self._validate_content_type(content_type):
                return Response("Invalid Content-Type header", status_code=400)

        # Skip remaining validation if DNS rebinding protection is disabled
        if not self.settings.enable_dns_rebinding_protection:
            return None

        # Validate Host header
        host = request.headers.get("host")  # pragma: lax no cover
        if not self._validate_host(host):  # pragma: lax no cover
            return Response("Invalid Host header", status_code=421)

        # Validate Origin header
        origin = request.headers.get("origin")  # pragma: lax no cover
        if not self._validate_origin(origin):  # pragma: lax no cover
            return Response("Invalid Origin header", status_code=403)

        return None  # pragma: lax no cover