Skip to content

Commit a9cf1de

Browse files
authored
PYTHON-5814 Configurable DNS domain validation for SRV records (#2903)
1 parent 7cea267 commit a9cf1de

25 files changed

Lines changed: 16735 additions & 22 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ repos:
115115
# - test/versioned-api/crud-api-version-1-strict.json:514: nin ==> inn, min, bin, nine
116116
# - test/test_client.py:188: te ==> the, be, we, to
117117
args: ["-L", "fle,fo,infinit,isnt,nin,te,aks"]
118+
exclude: ^pymongo/public_suffix_list\.dat$
118119

119120
- repo: local
120121
hooks:

doc/changelog.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ Changelog
44
Changes in Version 4.18.0
55
-------------------------
66

7+
- Added the ``srvAllowedHostsSuffix`` URI option and :class:`~pymongo.mongo_client.MongoClient`
8+
keyword argument. When connecting via ``mongodb+srv://``, this option overrides the default
9+
requirement that SRV-returned hosts share the same parent domain as the seed hostname,
10+
allowing hosts under a different domain suffix to be accepted. The suffix must contain at
11+
least two labels and must not be a public suffix. See the
12+
:class:`~pymongo.mongo_client.MongoClient` documentation for security considerations.
713
- Improved TLS connection performance by reusing TLS sessions across connections
814
to the same server, avoiding a full handshake on each new connection.
915
Session resumption is supported on all Python versions for synchronous clients

pymongo/_psl.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2024-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you
4+
# may not use this file except in compliance with the License. You
5+
# may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12+
# implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
"""Public Suffix List lookup for srvAllowedHostsSuffix validation."""
16+
17+
from __future__ import annotations
18+
19+
from pathlib import Path
20+
from typing import Optional
21+
22+
_PUBLIC_SUFFIXES: Optional[tuple[set[str], set[str], set[str]]] = None
23+
24+
25+
def _load_public_suffixes() -> tuple[set[str], set[str], set[str]]:
    """Parse the bundled Public Suffix List into three lookup sets.

    Reads ``public_suffix_list.dat`` from the directory containing this
    module, decoding it as UTF-8.

    Returns a ``(suffixes, wildcards, exceptions)`` tuple of lower-cased rules:

    - ``suffixes``: plain rules, stored as written (e.g. ``co.uk``).
    - ``wildcards``: ``*.`` rules, stored without the leading ``*.``.
    - ``exceptions``: ``!`` rules, stored without the leading ``!``.

    Blank lines and lines starting with ``//`` are skipped.
    """
    path = Path(__file__).parent / "public_suffix_list.dat"
    suffixes: set[str] = set()
    wildcards: set[str] = set()
    exceptions: set[str] = set()
    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            # The list format defines a rule as everything up to the first
            # whitespace, so anything after the first token is ignored rather
            # than being stored as part of the rule.
            fields = raw_line.split(None, 1)
            if not fields or fields[0].startswith("//"):
                continue
            rule = fields[0].lower()
            if rule.startswith("!"):
                exceptions.add(rule[1:])
            elif rule.startswith("*."):
                wildcards.add(rule[2:])
            else:
                suffixes.add(rule)
    return suffixes, wildcards, exceptions
42+
43+
44+
def is_public_suffix(domain: str) -> bool:
    """Return True if domain is a public suffix per the bundled Public Suffix List.

    The list is loaded on first use and cached in the module-level
    ``_PUBLIC_SUFFIXES``. ``domain`` is lower-cased and stripped of leading
    and trailing dots before lookup, so ``".Co.UK."`` and ``"co.uk"`` are
    treated alike.

    A name is reported as a public suffix when it is listed as a plain rule,
    or when everything after its first label is listed as a wildcard rule.
    A name listed as an exception rule is never a public suffix.

    Names containing punycode labels (``xn--``) are additionally checked in
    their IDNA-decoded Unicode form, so that a rule stored in Unicode is also
    recognised when the caller spells it in punycode.
    """
    global _PUBLIC_SUFFIXES  # noqa: PLW0603
    if _PUBLIC_SUFFIXES is None:
        _PUBLIC_SUFFIXES = _load_public_suffixes()
    suffixes, wildcards, exceptions = _PUBLIC_SUFFIXES

    domain = domain.lower().strip(".")
    candidates = [domain]
    if "xn--" in domain:
        try:
            candidates.append(domain.encode("ascii").decode("idna").lower())
        except UnicodeError:
            # Not valid punycode (or not pure ASCII): only the name as given
            # can be looked up.
            pass

    for name in candidates:
        if name in exceptions:
            return False
        if name in suffixes:
            return True
        # A wildcard rule "*.parent" covers exactly one extra leading label.
        _, dot, parent = name.partition(".")
        if dot and parent in wildcards:
            return True
    return False

pymongo/asynchronous/mongo_client.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,30 @@ def __init__(
450450
connect to. More specifically, when a "mongodb+srv://" connection string
451451
resolves to more than srvMaxHosts number of hosts, the client will randomly
452452
choose an srvMaxHosts sized subset of hosts.
453+
- `srvAllowedHostsSuffix`: (string) Overrides the default requirement that
454+
hosts returned by SRV DNS records share the same parent domain as the seed
455+
hostname. When set, the driver accepts any returned host whose name ends
456+
with this suffix (e.g. ``".atlas.mongodb.com"``). The value must contain
457+
at least two labels and must not be a public suffix (per the Public Suffix
458+
List). Only valid with ``mongodb+srv://`` URIs.
459+
460+
.. warning::
461+
462+
This option relaxes a built-in DNS spoofing safeguard. Use the most
463+
specific suffix possible for your deployment rather than a broad
464+
company-wide domain. For example, instead of::
465+
466+
AsyncMongoClient(
467+
"mongodb+srv://cluster.test.internal.example.com/",
468+
srvAllowedHostsSuffix=".example.com",
469+
)
470+
471+
which would accept any host across the entire domain, scope it further like so::
472+
473+
AsyncMongoClient(
474+
"mongodb+srv://cluster.test.internal.example.com/",
475+
srvAllowedHostsSuffix=".internal.example.com",
476+
)
453477
454478
455479
| **Write Concern options:**
@@ -803,6 +827,7 @@ def __init__(
803827
fqdn = None
804828
srv_service_name = keyword_opts.get("srvservicename")
805829
srv_max_hosts = keyword_opts.get("srvmaxhosts")
830+
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
806831
if len([h for h in self._host if "/" in h]) > 1:
807832
raise ConfigurationError("host must not contain multiple MongoDB URIs")
808833
for entity in self._host:
@@ -853,6 +878,8 @@ def __init__(
853878
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
854879

855880
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
881+
if srv_allowed_hosts_suffix is None:
882+
srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix")
856883
opts = self._normalize_and_validate_options(opts, self._seeds)
857884

858885
# Username and password passed as kwargs override user info in URI.
@@ -890,7 +917,9 @@ def __init__(
890917

891918
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
892919

893-
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
920+
self._init_based_on_options(
921+
self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
922+
)
894923

895924
self._opened = False
896925
self._closed = False
@@ -908,6 +937,7 @@ async def _resolve_srv(self) -> None:
908937
opts = common._CaseInsensitiveDictionary()
909938
srv_service_name = keyword_opts.get("srvservicename")
910939
srv_max_hosts = keyword_opts.get("srvmaxhosts")
940+
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
911941
for entity in self._host:
912942
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
913943
# it must be a URI,
@@ -928,6 +958,7 @@ async def _resolve_srv(self) -> None:
928958
connect_timeout=timeout,
929959
srv_service_name=srv_service_name,
930960
srv_max_hosts=srv_max_hosts,
961+
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
931962
)
932963
seeds.update(res["nodelist"])
933964
opts = res["options"]
@@ -960,6 +991,8 @@ async def _resolve_srv(self) -> None:
960991
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
961992

962993
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
994+
if srv_allowed_hosts_suffix is None:
995+
srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix")
963996
opts = self._normalize_and_validate_options(opts, seeds)
964997

965998
# Username and password passed as kwargs override user info in URI.
@@ -969,10 +1002,16 @@ async def _resolve_srv(self) -> None:
9691002
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
9701003
)
9711004

972-
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
1005+
self._init_based_on_options(
1006+
seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
1007+
)
9731008

9741009
def _init_based_on_options(
975-
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
1010+
self,
1011+
seeds: Collection[tuple[str, int]],
1012+
srv_max_hosts: Any,
1013+
srv_service_name: Any,
1014+
srv_allowed_hosts_suffix: Any,
9761015
) -> None:
9771016
self._event_listeners = self._options.pool_options._event_listeners
9781017
self._topology_settings = TopologySettings(
@@ -991,6 +1030,7 @@ def _init_based_on_options(
9911030
load_balanced=self._options.load_balanced,
9921031
srv_service_name=srv_service_name,
9931032
srv_max_hosts=srv_max_hosts,
1033+
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
9941034
server_monitoring_mode=self._options.server_monitoring_mode,
9951035
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
9961036
)

pymongo/asynchronous/monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
418418
self._fqdn,
419419
self._settings.pool_options.connect_timeout,
420420
self._settings.srv_service_name,
421+
srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix,
421422
)
422423
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
423424
if len(seedlist) == 0:

pymongo/asynchronous/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
load_balanced: Optional[bool] = None,
5353
srv_service_name: str = common.SRV_SERVICE_NAME,
5454
srv_max_hosts: int = 0,
55+
srv_allowed_hosts_suffix: Optional[str] = None,
5556
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
5657
topology_id: Optional[ObjectId] = None,
5758
):
@@ -79,6 +80,7 @@ def __init__(
7980
self._load_balanced = load_balanced
8081
self._srv_service_name = srv_service_name
8182
self._srv_max_hosts = srv_max_hosts or 0
83+
self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix
8284
self._server_monitoring_mode = server_monitoring_mode
8385
if topology_id is not None:
8486
self._topology_id = topology_id
@@ -156,6 +158,11 @@ def srv_max_hosts(self) -> int:
156158
"""The srvMaxHosts."""
157159
return self._srv_max_hosts
158160

161+
@property
162+
def srv_allowed_hosts_suffix(self) -> Optional[str]:
163+
"""The srvAllowedHostsSuffix."""
164+
return self._srv_allowed_hosts_suffix
165+
159166
@property
160167
def server_monitoring_mode(self) -> str:
161168
"""The serverMonitoringMode."""

pymongo/asynchronous/srv_resolver.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import random
2121
from typing import TYPE_CHECKING, Any, Optional, Union
2222

23+
from pymongo._psl import is_public_suffix
2324
from pymongo.common import CONNECT_TIMEOUT
2425
from pymongo.errors import ConfigurationError
2526

@@ -71,11 +72,29 @@ def __init__(
7172
connect_timeout: Optional[float],
7273
srv_service_name: str,
7374
srv_max_hosts: int = 0,
75+
srv_allowed_hosts_suffix: Optional[str] = None,
7476
):
75-
self.__fqdn = fqdn
77+
self.__fqdn = fqdn.lower()
7678
self.__srv = srv_service_name
7779
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
7880
self.__srv_max_hosts = srv_max_hosts or 0
81+
self.__srv_allowed_hosts_suffix = (
82+
"." + srv_allowed_hosts_suffix.lower().strip(".") if srv_allowed_hosts_suffix else None
83+
) # ensure there's a . at the beginning of the domain
84+
if (
85+
self.__srv_allowed_hosts_suffix is not None
86+
and "." not in self.__srv_allowed_hosts_suffix[1:]
87+
):
88+
raise ConfigurationError(
89+
"srvAllowedHostsSuffix must contain at least two labels (e.g. '.mydomain.net'), "
90+
f"got: {srv_allowed_hosts_suffix}"
91+
)
92+
if self.__srv_allowed_hosts_suffix is not None and is_public_suffix(
93+
self.__srv_allowed_hosts_suffix
94+
):
95+
raise ConfigurationError(
96+
f"srvAllowedHostsSuffix must not be a public suffix, got: {srv_allowed_hosts_suffix}"
97+
)
7998
# Validate the fully qualified domain name.
8099
try:
81100
ipaddress.ip_address(fqdn)
@@ -135,12 +154,16 @@ async def _get_srv_response_and_hosts(
135154
raise ConfigurationError(
136155
"Invalid SRV host: return address is identical to SRV hostname"
137156
)
138-
try:
139-
nlist = srv_host.split(".")[1:][-self.__slen :]
140-
except Exception as exc:
141-
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
142-
if self.__plist != nlist:
143-
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
157+
if self.__srv_allowed_hosts_suffix is not None:
158+
if not srv_host.endswith(self.__srv_allowed_hosts_suffix):
159+
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
160+
else:
161+
try:
162+
nlist = srv_host.split(".")[1:][-self.__slen :]
163+
except Exception as exc:
164+
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
165+
if self.__plist != nlist:
166+
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
144167
if self.__srv_max_hosts:
145168
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
146169
return results, nodes

pymongo/asynchronous/uri_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ async def parse_uri(
4848
connect_timeout: Optional[float] = None,
4949
srv_service_name: Optional[str] = None,
5050
srv_max_hosts: Optional[int] = None,
51+
srv_allowed_hosts_suffix: Optional[str] = None,
5152
) -> dict[str, Any]:
5253
"""Parse and validate a MongoDB URI.
5354
@@ -116,6 +117,7 @@ async def parse_uri(
116117
connect_timeout,
117118
srv_service_name,
118119
srv_max_hosts,
120+
srv_allowed_hosts_suffix,
119121
)
120122
)
121123
result["options"] = _make_options_case_sensitive(result["options"])
@@ -131,6 +133,7 @@ async def _parse_srv(
131133
connect_timeout: Optional[float] = None,
132134
srv_service_name: Optional[str] = None,
133135
srv_max_hosts: Optional[int] = None,
136+
srv_allowed_hosts_suffix: Optional[str] = None,
134137
) -> dict[str, Any]:
135138
if uri.startswith(SCHEME):
136139
is_srv = False
@@ -158,14 +161,17 @@ async def _parse_srv(
158161

159162
hosts = unquote_plus(hosts)
160163
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
164+
srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix")
161165
if is_srv:
162166
nodes = split_hosts(hosts, default_port=None)
163167
fqdn, _port = nodes[0]
164168

165169
# Use the connection timeout. connectTimeoutMS passed as a keyword
166170
# argument overrides the same option passed in the connection string.
167171
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
168-
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
172+
dns_resolver = _SrvResolver(
173+
fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix
174+
)
169175
nodes = await dns_resolver.get_hosts()
170176
dns_options = await dns_resolver.get_options()
171177
if dns_options:

pymongo/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ def validate_server_monitoring_mode(option: str, value: str) -> str:
718718
"zlibcompressionlevel": validate_zlib_compression_level,
719719
"srvservicename": validate_string,
720720
"srvmaxhosts": validate_non_negative_integer,
721+
"srvallowedhostssuffix": validate_string,
721722
"timeoutms": validate_timeoutms,
722723
"servermonitoringmode": validate_server_monitoring_mode,
723724
"maxadaptiveretries": validate_non_negative_integer,

0 commit comments

Comments
 (0)