-
Notifications
You must be signed in to change notification settings - Fork 1.2k
PYTHON-5814 Configurable DNS domain validation for SRV records #2903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
2cd02da
499a3ec
3423d39
3e30434
1ac4967
a8a3b4e
9663485
a71dd7a
4a6ba01
466a47e
7902127
c8f8c9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| # Copyright 2024-present MongoDB, Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); you | ||
| # may not use this file except in compliance with the License. You | ||
| # may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
| # implied. See the License for the specific language governing | ||
| # permissions and limitations under the License. | ||
|
|
||
| """Public Suffix List lookup for srvAllowedHostsSuffix validation.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
from functools import lru_cache
from pathlib import Path
|
|
||
|
|
||
# Review feedback on this loader: cache the parsed list so that checking a
# suffix does not trigger a full file read every time. ``lru_cache`` on a
# zero-argument function gives exactly one parse per process.
@lru_cache(maxsize=1)
def _load_public_suffixes() -> tuple[frozenset[str], frozenset[str], frozenset[str]]:
    """Parse the bundled ``public_suffix_list.dat`` into three rule sets.

    The file is expected to sit next to this module. Blank lines and lines
    starting with ``//`` are skipped. Every rule is lower-cased.

    :return: A ``(suffixes, wildcards, exceptions)`` tuple where

        - ``suffixes`` holds plain rules verbatim (``co.uk``),
        - ``wildcards`` holds ``*.``-rules with the ``*.`` prefix removed
          (``*.ck`` is stored as ``ck``),
        - ``exceptions`` holds ``!``-rules with the ``!`` removed
          (``!www.ck`` is stored as ``www.ck``).

        The sets are frozen because the cached tuple is shared by all callers.
    :raises OSError: If the data file cannot be opened.
    """
    path = Path(__file__).parent / "public_suffix_list.dat"
    suffixes: set[str] = set()
    wildcards: set[str] = set()
    exceptions: set[str] = set()
    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            # The PSL format reads each line only up to the first whitespace,
            # so anything after the rule itself is ignored.
            fields = raw_line.split(None, 1)
            if not fields or fields[0].startswith("//"):
                continue
            rule = fields[0].lower()
            if rule.startswith("!"):
                exceptions.add(rule[1:])
            elif rule.startswith("*."):
                wildcards.add(rule[2:])
            else:
                suffixes.add(rule)
    return frozenset(suffixes), frozenset(wildcards), frozenset(exceptions)


def is_public_suffix(domain: str) -> bool:
    """Return True if domain is a public suffix per the bundled Public Suffix List.

    The comparison is case-insensitive and ignores leading and trailing dots,
    so ``".CO.UK."`` is treated like ``"co.uk"``.

    :param domain: The domain name (or dotted suffix) to classify.
    :return: ``False`` if ``domain`` is listed as an exception rule; ``True``
        if it is listed as a plain rule, or if everything after its first
        label is the base of a wildcard rule; ``False`` otherwise.
    """
    suffixes, wildcards, exceptions = _load_public_suffixes()

    domain = domain.lower().strip(".")
    # Exception rules take priority over the wildcard they carve a hole in.
    if domain in exceptions:
        return False
    if domain in suffixes:
        return True
    # A wildcard rule "*.<base>" matches exactly one extra label on the left
    # of <base>, so drop the first label and look up the remainder.
    _, dot, parent = domain.partition(".")
    return bool(dot) and parent in wildcards
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -450,6 +450,18 @@ def __init__( | |
| connect to. More specifically, when a "mongodb+srv://" connection string | ||
| resolves to more than srvMaxHosts number of hosts, the client will randomly | ||
| choose an srvMaxHosts sized subset of hosts. | ||
| - `srvAllowedHostsSuffix`: (string) Overrides the default requirement that | ||
| hosts returned by SRV DNS records share the same parent domain as the seed | ||
| hostname. When set, the driver accepts any returned host whose name ends | ||
| with this suffix (e.g. ``".atlas.mongodb.com"``). The value must contain | ||
| at least two labels and must not be a public suffix (per the Public Suffix | ||
| List). Only valid with ``mongodb+srv://`` URIs. | ||
|
|
||
| .. warning:: | ||
|
|
||
| This option relaxes a built-in DNS spoofing safeguard. Use the most | ||
| specific suffix possible for your deployment rather than a broad | ||
| company-wide domain. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add an example of a good configuration just so the users can have an explicit understanding of what to do?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done in c8f8c9f! |
||
|
|
||
|
|
||
| | **Write Concern options:** | ||
|
|
@@ -802,6 +814,7 @@ def __init__( | |
| fqdn = None | ||
| srv_service_name = keyword_opts.get("srvservicename") | ||
| srv_max_hosts = keyword_opts.get("srvmaxhosts") | ||
| srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") | ||
| if len([h for h in self._host if "/" in h]) > 1: | ||
| raise ConfigurationError("host must not contain multiple MongoDB URIs") | ||
| for entity in self._host: | ||
|
|
@@ -852,6 +865,8 @@ def __init__( | |
| srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) | ||
|
|
||
| srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") | ||
| if srv_allowed_hosts_suffix is None: | ||
| srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix") | ||
| opts = self._normalize_and_validate_options(opts, self._seeds) | ||
|
|
||
| # Username and password passed as kwargs override user info in URI. | ||
|
|
@@ -889,7 +904,9 @@ def __init__( | |
|
|
||
| self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) | ||
|
|
||
| self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) | ||
| self._init_based_on_options( | ||
| self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix | ||
| ) | ||
|
|
||
| self._opened = False | ||
| self._closed = False | ||
|
|
@@ -907,6 +924,7 @@ async def _resolve_srv(self) -> None: | |
| opts = common._CaseInsensitiveDictionary() | ||
| srv_service_name = keyword_opts.get("srvservicename") | ||
| srv_max_hosts = keyword_opts.get("srvmaxhosts") | ||
| srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") | ||
| for entity in self._host: | ||
| # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' | ||
| # it must be a URI, | ||
|
|
@@ -927,6 +945,7 @@ async def _resolve_srv(self) -> None: | |
| connect_timeout=timeout, | ||
| srv_service_name=srv_service_name, | ||
| srv_max_hosts=srv_max_hosts, | ||
| srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, | ||
| ) | ||
| seeds.update(res["nodelist"]) | ||
| opts = res["options"] | ||
|
|
@@ -959,6 +978,8 @@ async def _resolve_srv(self) -> None: | |
| srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) | ||
|
|
||
| srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") | ||
| if srv_allowed_hosts_suffix is None: | ||
| srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix") | ||
| opts = self._normalize_and_validate_options(opts, seeds) | ||
|
|
||
| # Username and password passed as kwargs override user info in URI. | ||
|
|
@@ -968,10 +989,16 @@ async def _resolve_srv(self) -> None: | |
| username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC | ||
| ) | ||
|
|
||
| self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) | ||
| self._init_based_on_options( | ||
| seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix | ||
| ) | ||
|
|
||
| def _init_based_on_options( | ||
| self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any | ||
| self, | ||
| seeds: Collection[tuple[str, int]], | ||
| srv_max_hosts: Any, | ||
| srv_service_name: Any, | ||
| srv_allowed_hosts_suffix: Any, | ||
| ) -> None: | ||
| self._event_listeners = self._options.pool_options._event_listeners | ||
| self._topology_settings = TopologySettings( | ||
|
|
@@ -990,6 +1017,7 @@ def _init_based_on_options( | |
| load_balanced=self._options.load_balanced, | ||
| srv_service_name=srv_service_name, | ||
| srv_max_hosts=srv_max_hosts, | ||
| srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, | ||
| server_monitoring_mode=self._options.server_monitoring_mode, | ||
| topology_id=self._topology_settings._topology_id if self._topology_settings else None, | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| import random | ||
| from typing import TYPE_CHECKING, Any, Optional, Union | ||
|
|
||
| from pymongo._psl import is_public_suffix | ||
| from pymongo.common import CONNECT_TIMEOUT | ||
| from pymongo.errors import ConfigurationError | ||
|
|
||
|
|
@@ -71,11 +72,29 @@ def __init__( | |
| connect_timeout: Optional[float], | ||
| srv_service_name: str, | ||
| srv_max_hosts: int = 0, | ||
| srv_allowed_hosts_suffix: Optional[str] = None, | ||
| ): | ||
| self.__fqdn = fqdn | ||
| self.__fqdn = fqdn.lower() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this `.lower()` call on `fqdn` necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not strictly necessary. I think my bot(s) noticed an existing bug, because I added a test case about case insensitivity for the new parameter I'm introducing. For example, if a user passed in a URI like "mongodb+srv://test1.TEST.BUILD.10GEN.CC/" instead of "mongodb+srv://test1.test.build.10gen.cc/", the current logic (without the `.lower()` call) [the rest of this comment is truncated in this capture] |
||
| self.__srv = srv_service_name | ||
| self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT | ||
| self.__srv_max_hosts = srv_max_hosts or 0 | ||
| self.__srv_allowed_hosts_suffix = ( | ||
| "." + srv_allowed_hosts_suffix.lower().strip(".") if srv_allowed_hosts_suffix else None | ||
| ) # ensure there's a . at the beginning of the domain | ||
| if ( | ||
| self.__srv_allowed_hosts_suffix is not None | ||
| and "." not in self.__srv_allowed_hosts_suffix[1:] | ||
| ): | ||
| raise ConfigurationError( | ||
| "srvAllowedHostsSuffix must contain at least two labels (e.g. '.mydomain.net'), " | ||
| f"got: {srv_allowed_hosts_suffix}" | ||
| ) | ||
| if self.__srv_allowed_hosts_suffix is not None and is_public_suffix( | ||
| self.__srv_allowed_hosts_suffix | ||
| ): | ||
| raise ConfigurationError( | ||
| f"srvAllowedHostsSuffix must not be a public suffix, got: {srv_allowed_hosts_suffix}" | ||
| ) | ||
| # Validate the fully qualified domain name. | ||
| try: | ||
| ipaddress.ip_address(fqdn) | ||
|
|
@@ -135,12 +154,16 @@ async def _get_srv_response_and_hosts( | |
| raise ConfigurationError( | ||
| "Invalid SRV host: return address is identical to SRV hostname" | ||
| ) | ||
| try: | ||
| nlist = srv_host.split(".")[1:][-self.__slen :] | ||
| except Exception as exc: | ||
| raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc | ||
| if self.__plist != nlist: | ||
| raise ConfigurationError(f"Invalid SRV host: {node[0]}") | ||
| if self.__srv_allowed_hosts_suffix is not None: | ||
| if not srv_host.endswith(self.__srv_allowed_hosts_suffix): | ||
| raise ConfigurationError(f"Invalid SRV host: {node[0]}") | ||
| else: | ||
| try: | ||
| nlist = srv_host.split(".")[1:][-self.__slen :] | ||
| except Exception as exc: | ||
| raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc | ||
| if self.__plist != nlist: | ||
| raise ConfigurationError(f"Invalid SRV host: {node[0]}") | ||
| if self.__srv_max_hosts: | ||
| nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) | ||
| return results, nodes | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PSL consists mostly of non-words, so spell-checking it raises lots of "typos" that aren't actually typos.