|
| 1 | +""" |
| 2 | +SSRF (Server-Side Request Forgery) protection utilities. |
| 3 | +
|
| 4 | +Provides a requests.Session that validates outbound URLs against private/reserved |
| 5 | +IP ranges at socket-creation time, closing the DNS rebinding (TOCTOU) window that |
| 6 | +exists when validation is performed only as a pre-flight step. |
| 7 | +
|
| 8 | +Usage: |
| 9 | + from dojo.utils_ssrf import make_ssrf_safe_session, validate_url_for_ssrf, SSRFError |
| 10 | +
|
| 11 | + # Pre-flight validation (raises SSRFError with a human-readable message): |
| 12 | + validate_url_for_ssrf(url) |
| 13 | +
|
| 14 | + # Safe session (validates at socket-creation time on every request): |
| 15 | + session = make_ssrf_safe_session() |
| 16 | + response = session.get(url) |
| 17 | +""" |
| 18 | + |
| 19 | +import ipaddress |
| 20 | +import socket |
| 21 | +from urllib.parse import urlparse |
| 22 | + |
| 23 | +import requests |
| 24 | +import urllib3.connection |
| 25 | +import urllib3.connectionpool |
| 26 | +from requests.adapters import DEFAULT_POOLBLOCK, DEFAULT_POOLSIZE, HTTPAdapter |
| 27 | + |
| 28 | + |
| 29 | +class SSRFError(ValueError): |
| 30 | + |
| 31 | + """Raised when a URL is determined to be unsafe for server-side requests.""" |
| 32 | + |
| 33 | + |
| 34 | +_ALLOWED_SCHEMES = frozenset({"http", "https"}) |
| 35 | + |
| 36 | + |
| 37 | +def _check_ip(ip_str: str) -> None: |
| 38 | + """Raise SSRFError if the IP address is not globally routable.""" |
| 39 | + try: |
| 40 | + ip = ipaddress.ip_address(ip_str) |
| 41 | + except ValueError as exc: |
| 42 | + msg = f"Cannot parse IP address: {ip_str!r}" |
| 43 | + raise SSRFError(msg) from exc |
| 44 | + |
| 45 | + # ip.is_global is False for loopback, link-local (169.254.x.x), RFC 1918, |
| 46 | + # reserved, multicast, and unspecified addresses. |
| 47 | + if not ip.is_global: |
| 48 | + msg = ( |
| 49 | + f"Blocked: URL resolved to non-public address {ip}. " |
| 50 | + "Requests to private, loopback, link-local, or reserved " |
| 51 | + "addresses are not permitted." |
| 52 | + ) |
| 53 | + raise SSRFError(msg) |
| 54 | + |
| 55 | + |
| 56 | +def _resolve_and_check(hostname: str, port: int) -> None: |
| 57 | + """Resolve hostname and verify every returned address is publicly routable.""" |
| 58 | + try: |
| 59 | + addr_infos = socket.getaddrinfo( |
| 60 | + hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM, |
| 61 | + ) |
| 62 | + except socket.gaierror as exc: |
| 63 | + msg = f"Unable to resolve hostname {hostname!r}: {exc}" |
| 64 | + raise SSRFError(msg) from exc |
| 65 | + |
| 66 | + if not addr_infos: |
| 67 | + msg = f"No addresses returned for hostname {hostname!r}" |
| 68 | + raise SSRFError(msg) |
| 69 | + |
| 70 | + for _family, _type, _proto, _canon, sockaddr in addr_infos: |
| 71 | + _check_ip(sockaddr[0]) |
| 72 | + |
| 73 | + |
| 74 | +def validate_url_for_ssrf(url: str) -> None: |
| 75 | + """ |
| 76 | + Pre-flight SSRF validation for a URL. |
| 77 | +
|
| 78 | + Checks: |
| 79 | + - Scheme is http or https (blocks file://, gopher://, etc.) |
| 80 | + - Every resolved IP address is globally routable (blocks RFC 1918, |
| 81 | + loopback 127.x, link-local 169.254.x.x, and other reserved ranges) |
| 82 | +
|
| 83 | + Raises SSRFError with a descriptive message if the URL is unsafe. |
| 84 | + This is a best-effort pre-flight check; use make_ssrf_safe_session() for |
| 85 | + socket-level enforcement that also mitigates DNS rebinding. |
| 86 | + """ |
| 87 | + try: |
| 88 | + parsed = urlparse(url) |
| 89 | + except Exception as exc: |
| 90 | + msg = f"Malformed URL: {url!r}" |
| 91 | + raise SSRFError(msg) from exc |
| 92 | + |
| 93 | + if parsed.scheme not in _ALLOWED_SCHEMES: |
| 94 | + msg = ( |
| 95 | + f"URL scheme {parsed.scheme!r} is not permitted. " |
| 96 | + "Only 'http' and 'https' are allowed." |
| 97 | + ) |
| 98 | + raise SSRFError(msg) |
| 99 | + |
| 100 | + hostname = parsed.hostname |
| 101 | + if not hostname: |
| 102 | + msg = f"URL has no hostname: {url!r}" |
| 103 | + raise SSRFError(msg) |
| 104 | + |
| 105 | + port = parsed.port or (443 if parsed.scheme == "https" else 80) |
| 106 | + _resolve_and_check(hostname, port) |
| 107 | + |
| 108 | + |
| 109 | +# --------------------------------------------------------------------------- |
| 110 | +# urllib3 connection subclasses — validation runs at socket-creation time. |
| 111 | +# Overriding _new_conn() (called immediately before the OS connect() syscall) |
| 112 | +# minimises the TOCTOU window to microseconds, making DNS rebinding attacks |
| 113 | +# impractical in practice. |
| 114 | +# --------------------------------------------------------------------------- |
| 115 | + |
| 116 | +class _SSRFSafeHTTPConnection(urllib3.connection.HTTPConnection): |
| 117 | + def _new_conn(self) -> socket.socket: |
| 118 | + _resolve_and_check(self._dns_host, self.port) |
| 119 | + return super()._new_conn() |
| 120 | + |
| 121 | + |
| 122 | +class _SSRFSafeHTTPSConnection(urllib3.connection.HTTPSConnection): |
| 123 | + def _new_conn(self) -> socket.socket: |
| 124 | + _resolve_and_check(self._dns_host, self.port) |
| 125 | + return super()._new_conn() |
| 126 | + |
| 127 | + |
| 128 | +class _SSRFSafeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): |
| 129 | + ConnectionCls = _SSRFSafeHTTPConnection |
| 130 | + |
| 131 | + |
| 132 | +class _SSRFSafeHTTPSConnectionPool(urllib3.connectionpool.HTTPSConnectionPool): |
| 133 | + ConnectionCls = _SSRFSafeHTTPSConnection |
| 134 | + |
| 135 | + |
| 136 | +_SAFE_POOL_CLASSES = { |
| 137 | + "http": _SSRFSafeHTTPConnectionPool, |
| 138 | + "https": _SSRFSafeHTTPSConnectionPool, |
| 139 | +} |
| 140 | + |
| 141 | + |
| 142 | +class _SSRFSafeAdapter(HTTPAdapter): |
| 143 | + |
| 144 | + """ |
| 145 | + A requests HTTPAdapter that injects SSRF-safe connection classes into the |
| 146 | + urllib3 pool manager so that IP validation happens at socket-creation time |
| 147 | + on every request, including after redirects. |
| 148 | + """ |
| 149 | + |
| 150 | + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): |
| 151 | + super().init_poolmanager(connections, maxsize, block, **pool_kwargs) |
| 152 | + # Replace the pool classes after the manager is created. |
| 153 | + # pool_classes_by_scheme is a plain dict on the instance, so this |
| 154 | + # only affects this adapter's pool manager. |
| 155 | + self.poolmanager.pool_classes_by_scheme = _SAFE_POOL_CLASSES |
| 156 | + |
| 157 | + |
| 158 | +def make_ssrf_safe_session() -> requests.Session: |
| 159 | + """ |
| 160 | + Return a requests.Session with SSRF protection applied at the socket level. |
| 161 | +
|
| 162 | + Every outbound request made through this session will have its resolved IP |
| 163 | + validated against the private/reserved range blocklist immediately before |
| 164 | + the OS socket is opened, preventing both: |
| 165 | + - Direct requests to internal IP ranges |
| 166 | + - DNS rebinding attacks |
| 167 | + """ |
| 168 | + session = requests.Session() |
| 169 | + adapter = _SSRFSafeAdapter(pool_maxsize=DEFAULT_POOLSIZE) |
| 170 | + session.mount("http://", adapter) |
| 171 | + session.mount("https://", adapter) |
| 172 | + return session |
0 commit comments