Skip to content

Commit 3f9a2d1

Browse files
committed
feat(transport-security): add subdomain wildcard support for allowed_hosts
TransportSecuritySettings.allowed_hosts now supports *.domain patterns (e.g. *.mysite.com) so that a single entry can allow the base domain and any subdomain (app.mysite.com, api.mysite.com, etc.) instead of listing each host explicitly. This makes multi-subdomain or dynamic subdomain setups practical. - Add _hostname_from_host() to strip port from Host header (including IPv6) - In _validate_host(), treat entries starting with *. as subdomain wildcards: match hostname equal to base domain or ending with .<base> - Preserve existing behaviour: exact match and example.com:* port wildcard - Document the three pattern types in allowed_hosts docstring - Add integration tests for SSE and StreamableHTTP with *.mysite.com Github-Issue: #2141
1 parent 0fe16dd commit 3f9a2d1

File tree

4 files changed

+146
-3
lines changed

4 files changed

+146
-3
lines changed

src/mcp/server/transport_security.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ class TransportSecuritySettings(BaseModel):
2222
allowed_hosts: list[str] = Field(default_factory=list)
2323
"""List of allowed Host header values.
2424
25+
Supports:
26+
- Exact match: ``example.com``, ``127.0.0.1:8080``
27+
- Wildcard port: ``example.com:*`` matches ``example.com`` with any port
28+
- Subdomain wildcard: ``*.mysite.com`` matches ``mysite.com`` and any subdomain
29+
(e.g. ``app.mysite.com``, ``api.mysite.com``). Optionally use ``*.mysite.com:*``
30+
to also allow any port.
31+
2532
Only applies when `enable_dns_rebinding_protection` is `True`.
2633
"""
2734

@@ -40,6 +47,15 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4047
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4148
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4249

50+
def _hostname_from_host(self, host: str) -> str:
51+
"""Extract hostname from Host header (strip optional port)."""
52+
if host.startswith("["):
53+
idx = host.find("]:")
54+
if idx != -1:
55+
return host[: idx + 1]
56+
return host
57+
return host.split(":", 1)[0]
58+
4359
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
4460
"""Validate the Host header against allowed values."""
4561
if not host:
@@ -50,15 +66,27 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
5066
if host in self.settings.allowed_hosts:
5167
return True
5268

53-
# Check wildcard port patterns
69+
# Check wildcard port patterns (e.g. example.com:*)
5470
for allowed in self.settings.allowed_hosts:
5571
if allowed.endswith(":*"):
56-
# Extract base host from pattern
5772
base_host = allowed[:-2]
58-
# Check if the actual host starts with base host and has a port
73+
# Subdomain pattern *.domain.com:* is handled below; skip here
74+
if base_host.startswith("*."):
75+
continue
5976
if host.startswith(base_host + ":"):
6077
return True
6178

79+
# Check subdomain wildcard patterns (e.g. *.mysite.com or *.mysite.com:*)
80+
hostname = self._hostname_from_host(host)
81+
for allowed in self.settings.allowed_hosts:
82+
if allowed.startswith("*."):
83+
pattern = allowed[:-2] if allowed.endswith(":*") else allowed
84+
base_domain = pattern[2:]
85+
if not base_domain:
86+
continue
87+
if hostname == base_domain or hostname.endswith("." + base_domain):
88+
return True
89+
6290
logger.warning(f"Invalid Host header: {host}")
6391
return False
6492

tests/server/test_sse_security.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,59 @@ async def test_sse_security_wildcard_ports(server_port: int):
256256
process.join()
257257

258258

259+
@pytest.mark.anyio
260+
async def test_sse_security_ipv6_host_header(server_port: int):
261+
"""Test SSE with IPv6 Host header ([::1] and [::1]:port) to cover _hostname_from_host."""
262+
settings = TransportSecuritySettings(
263+
enable_dns_rebinding_protection=True,
264+
allowed_hosts=["127.0.0.1:*", "[::1]:*", "[::1]"],
265+
allowed_origins=["http://127.0.0.1:*", "http://[::1]:*"],
266+
)
267+
process = start_server_process(server_port, settings)
268+
269+
try:
270+
async with httpx.AsyncClient(timeout=5.0) as client:
271+
async with client.stream(
272+
"GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]:8080"}
273+
) as response:
274+
assert response.status_code == 200
275+
async with client.stream(
276+
"GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]"}
277+
) as response:
278+
assert response.status_code == 200
279+
finally:
280+
process.terminate()
281+
process.join()
282+
283+
284+
@pytest.mark.anyio
285+
async def test_sse_security_subdomain_wildcard_host(server_port: int):
286+
"""Test SSE with *.domain subdomain wildcard in allowed_hosts (issue #2141)."""
287+
settings = TransportSecuritySettings(
288+
enable_dns_rebinding_protection=True,
289+
allowed_hosts=["*.mysite.com", "127.0.0.1:*"],
290+
allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"],
291+
)
292+
process = start_server_process(server_port, settings)
293+
294+
try:
295+
# Allowed: subdomain and base domain
296+
for host in ["app.mysite.com", "api.mysite.com", "mysite.com"]:
297+
headers = {"Host": host}
298+
async with httpx.AsyncClient(timeout=5.0) as client:
299+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
300+
assert response.status_code == 200, f"Host {host} should be allowed"
301+
302+
# Rejected: other domain
303+
async with httpx.AsyncClient() as client:
304+
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers={"Host": "other.com"})
305+
assert response.status_code == 421
306+
assert response.text == "Invalid Host header"
307+
finally:
308+
process.terminate()
309+
process.join()
310+
311+
259312
@pytest.mark.anyio
260313
async def test_sse_security_post_valid_content_type(server_port: int):
261314
"""Test POST endpoint with valid Content-Type headers."""

tests/server/test_streamable_http_security.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,47 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int):
253253
process.join()
254254

255255

256+
@pytest.mark.anyio
257+
async def test_streamable_http_security_subdomain_wildcard_host(server_port: int):
258+
"""Test StreamableHTTP with *.domain subdomain wildcard in allowed_hosts (issue #2141)."""
259+
settings = TransportSecuritySettings(
260+
enable_dns_rebinding_protection=True,
261+
allowed_hosts=["*.mysite.com", "127.0.0.1:*"],
262+
allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"],
263+
)
264+
process = start_server_process(server_port, settings)
265+
266+
try:
267+
headers = {
268+
"Host": "app.mysite.com",
269+
"Accept": "application/json, text/event-stream",
270+
"Content-Type": "application/json",
271+
}
272+
async with httpx.AsyncClient(timeout=5.0) as client:
273+
response = await client.post(
274+
f"http://127.0.0.1:{server_port}/",
275+
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
276+
headers=headers,
277+
)
278+
assert response.status_code == 200
279+
280+
async with httpx.AsyncClient() as client:
281+
response = await client.post(
282+
f"http://127.0.0.1:{server_port}/",
283+
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
284+
headers={
285+
"Host": "other.com",
286+
"Accept": "application/json, text/event-stream",
287+
"Content-Type": "application/json",
288+
},
289+
)
290+
assert response.status_code == 421
291+
assert response.text == "Invalid Host header"
292+
finally:
293+
process.terminate()
294+
process.join()
295+
296+
256297
@pytest.mark.anyio
257298
async def test_streamable_http_security_get_request(server_port: int):
258299
"""Test StreamableHTTP GET request with security."""
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Tests for transport security (DNS rebinding protection)."""
2+
3+
from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
4+
5+
6+
def test_hostname_from_host_ipv6_with_port():
7+
"""_hostname_from_host strips port from [::1]:port (coverage for lines 52-55)."""
8+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
9+
assert m._hostname_from_host("[::1]:8080") == "[::1]"
10+
11+
12+
def test_hostname_from_host_ipv6_no_port():
13+
"""_hostname_from_host returns [::1] as-is when no port (coverage for line 56)."""
14+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
15+
assert m._hostname_from_host("[::1]") == "[::1]"
16+
17+
18+
def test_hostname_from_host_plain_with_port():
19+
"""_hostname_from_host strips port from hostname (coverage for line 57)."""
20+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
21+
assert m._hostname_from_host("app.mysite.com:8080") == "app.mysite.com"

0 commit comments

Comments
 (0)