Skip to content

Commit aec3fef

Browse files
Jino-T and claude authored
Added ssrf utils file to check urls and applied it to risk recon parser (#14631)
* added ssrf utils to check urls and applied it to risk recon parser
* update risk recon unit tests
* add unit tests for SSRF protection in risk recon API init

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* add unit tests for utils_ssrf module

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8a2100e commit aec3fef

File tree

4 files changed

+280
-7
lines changed

4 files changed

+280
-7
lines changed

dojo/tools/risk_recon/api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import requests
21
from django.conf import settings
32

3+
from dojo.utils_ssrf import SSRFError, make_ssrf_safe_session, validate_url_for_ssrf
4+
45

56
class RiskReconAPI:
67
def __init__(self, api_key, endpoint, data):
@@ -26,7 +27,14 @@ def __init__(self, api_key, endpoint, data):
2627
raise Exception(msg)
2728
if self.url.endswith("/"):
2829
self.url = endpoint[:-1]
29-
self.session = requests.Session()
30+
31+
try:
32+
validate_url_for_ssrf(self.url)
33+
except SSRFError as exc:
34+
msg = f"Invalid Risk Recon API url: {exc}"
35+
raise Exception(msg) from exc
36+
37+
self.session = make_ssrf_safe_session()
3038
self.map_toes()
3139
self.get_findings()
3240

@@ -54,7 +62,7 @@ def map_toes(self):
5462
filters = comps.get(name)
5563
self.toe_map[toe_id] = filters or self.data
5664
else:
57-
msg = f"Unable to query Target of Evaluations due to {response.status_code} - {response.content}"
65+
msg = f"Unable to query Target of Evaluations due to {response.status_code}"
5866
raise Exception(msg) # TODO: when implementing ruff BLE001, please fix also TODO in unittests/test_risk_recon.py
5967

6068
def filter_finding(self, finding):
@@ -86,5 +94,5 @@ def get_findings(self):
8694
if not self.filter_finding(finding):
8795
self.findings.append(finding)
8896
else:
89-
msg = f"Unable to collect findings from toe: {toe} due to {response.status_code} - {response.content}"
97+
msg = f"Unable to collect findings from toe: {toe} due to {response.status_code}"
9098
raise Exception(msg)

dojo/utils_ssrf.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
SSRF (Server-Side Request Forgery) protection utilities.
3+
4+
Provides a requests.Session that validates outbound URLs against private/reserved
5+
IP ranges at socket-creation time, closing the DNS rebinding (TOCTOU) window that
6+
exists when validation is performed only as a pre-flight step.
7+
8+
Usage:
9+
from dojo.utils_ssrf import make_ssrf_safe_session, validate_url_for_ssrf, SSRFError
10+
11+
# Pre-flight validation (raises SSRFError with a human-readable message):
12+
validate_url_for_ssrf(url)
13+
14+
# Safe session (validates at socket-creation time on every request):
15+
session = make_ssrf_safe_session()
16+
response = session.get(url)
17+
"""
18+
19+
import ipaddress
20+
import socket
21+
from urllib.parse import urlparse
22+
23+
import requests
24+
import urllib3.connection
25+
import urllib3.connectionpool
26+
from requests.adapters import DEFAULT_POOLBLOCK, DEFAULT_POOLSIZE, HTTPAdapter
27+
28+
29+
class SSRFError(ValueError):

    """
    Signal that a URL is unsafe to request from the server side.

    This is a ``ValueError`` subclass, so ``except ValueError`` handlers
    also catch it.
    """
32+
33+
34+
_ALLOWED_SCHEMES = frozenset({"http", "https"})
35+
36+
37+
def _check_ip(ip_str: str) -> None:
38+
"""Raise SSRFError if the IP address is not globally routable."""
39+
try:
40+
ip = ipaddress.ip_address(ip_str)
41+
except ValueError as exc:
42+
msg = f"Cannot parse IP address: {ip_str!r}"
43+
raise SSRFError(msg) from exc
44+
45+
# ip.is_global is False for loopback, link-local (169.254.x.x), RFC 1918,
46+
# reserved, multicast, and unspecified addresses.
47+
if not ip.is_global:
48+
msg = (
49+
f"Blocked: URL resolved to non-public address {ip}. "
50+
"Requests to private, loopback, link-local, or reserved "
51+
"addresses are not permitted."
52+
)
53+
raise SSRFError(msg)
54+
55+
56+
def _resolve_and_check(hostname: str, port: int) -> None:
    """
    Resolve hostname and verify every returned address is publicly routable.

    Args:
        hostname: Host name or IP literal to resolve.
        port: Port passed through to ``socket.getaddrinfo``.

    Raises:
        SSRFError: If the hostname cannot be resolved (including hostnames
            that cannot be IDNA-encoded), if resolution yields no addresses,
            or if any returned address is rejected by ``_check_ip``.

    """
    try:
        addr_infos = socket.getaddrinfo(
            hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM,
        )
    except (socket.gaierror, UnicodeError) as exc:
        # getaddrinfo IDNA-encodes str hostnames before resolving; a label
        # that is empty ("example..com") or too long raises UnicodeError
        # instead of gaierror. Treat both as "cannot resolve" so callers
        # only ever have to handle SSRFError.
        msg = f"Unable to resolve hostname {hostname!r}: {exc}"
        raise SSRFError(msg) from exc

    if not addr_infos:
        msg = f"No addresses returned for hostname {hostname!r}"
        raise SSRFError(msg)

    # Every candidate address must pass: a single non-public record is enough
    # to reject the host, since any of the records could be used to connect.
    for _family, _type, _proto, _canon, sockaddr in addr_infos:
        _check_ip(sockaddr[0])
72+
73+
74+
def validate_url_for_ssrf(url: str) -> None:
    """
    Pre-flight SSRF validation for a URL.

    Checks:
      - Scheme is http or https (blocks file://, gopher://, etc.)
      - The URL has a hostname and a well-formed port
      - Every resolved IP address is globally routable (blocks RFC 1918,
        loopback 127.x, link-local 169.254.x.x, and other reserved ranges)

    Raises SSRFError with a descriptive message if the URL is unsafe.
    This is a best-effort pre-flight check; use make_ssrf_safe_session() for
    enforcement at connection time as well.
    """
    try:
        parsed = urlparse(url)
    except Exception as exc:
        msg = f"Malformed URL: {url!r}"
        raise SSRFError(msg) from exc

    if parsed.scheme not in _ALLOWED_SCHEMES:
        msg = (
            f"URL scheme {parsed.scheme!r} is not permitted. "
            "Only 'http' and 'https' are allowed."
        )
        raise SSRFError(msg)

    hostname = parsed.hostname
    if not hostname:
        msg = f"URL has no hostname: {url!r}"
        raise SSRFError(msg)

    # ``.port`` is parsed lazily and raises a plain ValueError for a
    # non-numeric or out-of-range port (e.g. "http://host:99999/"). Convert
    # it so callers that catch SSRFError see every rejection.
    try:
        explicit_port = parsed.port
    except ValueError as exc:
        msg = f"Malformed URL: {url!r}"
        raise SSRFError(msg) from exc

    # Fall back to the scheme's default port when none is given.
    port = explicit_port or (443 if parsed.scheme == "https" else 80)
    _resolve_and_check(hostname, port)
107+
108+
109+
# ---------------------------------------------------------------------------
110+
# urllib3 connection subclasses — validation runs at socket-creation time.
111+
# Overriding _new_conn() (called immediately before the OS connect() syscall)
112+
# minimises the TOCTOU window to microseconds, making DNS rebinding attacks
113+
# impractical in practice.
114+
# ---------------------------------------------------------------------------
115+
116+
class _SSRFSafeHTTPConnection(urllib3.connection.HTTPConnection):

    """
    Plain-HTTP connection that vets its target before opening a socket.

    NOTE(review): the host is resolved here for validation and the parent
    ``_new_conn`` is then called without a pinned address, so it presumably
    resolves the name again — confirm whether that second lookup needs to be
    covered.
    """

    def _new_conn(self) -> socket.socket:
        """Validate the connection's host and port, then delegate to the parent."""
        target_host, target_port = self._dns_host, self.port
        _resolve_and_check(target_host, target_port)
        return super()._new_conn()
120+
121+
122+
class _SSRFSafeHTTPSConnection(urllib3.connection.HTTPSConnection):

    """
    HTTPS connection that vets its target before opening a socket.

    NOTE(review): the host is resolved here for validation and the parent
    ``_new_conn`` is then called without a pinned address, so it presumably
    resolves the name again — confirm whether that second lookup needs to be
    covered.
    """

    def _new_conn(self) -> socket.socket:
        """Validate the connection's host and port, then delegate to the parent."""
        target_host, target_port = self._dns_host, self.port
        _resolve_and_check(target_host, target_port)
        return super()._new_conn()
126+
127+
128+
class _SSRFSafeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):

    """HTTP connection pool whose connections are ``_SSRFSafeHTTPConnection``."""

    ConnectionCls = _SSRFSafeHTTPConnection
130+
131+
132+
class _SSRFSafeHTTPSConnectionPool(urllib3.connectionpool.HTTPSConnectionPool):

    """HTTPS connection pool whose connections are ``_SSRFSafeHTTPSConnection``."""

    ConnectionCls = _SSRFSafeHTTPSConnection
134+
135+
136+
# Scheme -> pool class mapping installed on the adapter's pool manager.
_SAFE_POOL_CLASSES = dict(
    http=_SSRFSafeHTTPConnectionPool,
    https=_SSRFSafeHTTPSConnectionPool,
)
140+
141+
142+
class _SSRFSafeAdapter(HTTPAdapter):

    """
    A requests HTTPAdapter that injects SSRF-safe connection classes into the
    urllib3 pool manager so that IP validation happens at socket-creation time
    on every request made through this adapter's pool manager.

    NOTE(review): only ``self.poolmanager`` is patched here. Connections that
    requests routes through a proxy are presumably built by a separate proxy
    manager that this override does not touch — confirm whether proxied
    requests need the same protection.
    """

    def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs):
        """Create the pool manager as usual, then swap in the SSRF-safe pool classes."""
        super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
        # Replace the mapping on this pool manager only. Assign a copy rather
        # than the module-level dict itself, so that an in-place change made
        # through one pool manager can never leak into other sessions.
        self.poolmanager.pool_classes_by_scheme = dict(_SAFE_POOL_CLASSES)
156+
157+
158+
def make_ssrf_safe_session() -> requests.Session:
    """
    Build a requests.Session whose HTTP and HTTPS traffic uses the SSRF-safe adapter.

    A single ``_SSRFSafeAdapter`` is mounted for both the ``http://`` and
    ``https://`` prefixes, so requests sent through the returned session use
    connections that validate their target address before connecting.
    """
    safe_session = requests.Session()
    shared_adapter = _SSRFSafeAdapter(pool_maxsize=DEFAULT_POOLSIZE)
    for prefix in ("http://", "https://"):
        safe_session.mount(prefix, shared_adapter)
    return safe_session

unittests/test_utils_ssrf.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import socket
2+
from unittest.mock import patch
3+
4+
import requests
5+
6+
from dojo.utils_ssrf import SSRFError, _SSRFSafeAdapter, make_ssrf_safe_session, validate_url_for_ssrf # noqa: PLC2701
7+
from unittests.dojo_test_case import DojoTestCase
8+
9+
10+
def _addr_info(ip, port=80):
11+
"""Build a minimal getaddrinfo-style return value for a single IP."""
12+
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))]
13+
14+
15+
_MIXED_ADDR_INFO = [
16+
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 80)),
17+
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 80)),
18+
]
19+
20+
21+
class TestValidateUrlForSsrf(DojoTestCase):

    """Exercise the pre-flight checks performed by ``validate_url_for_ssrf``."""

    def _assert_rejected(self, url, message_pattern):
        """Assert that validating *url* raises SSRFError matching *message_pattern*."""
        with self.assertRaisesRegex(SSRFError, message_pattern):
            validate_url_for_ssrf(url)

    @patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_addr_info("8.8.8.8"))
    def test_valid_public_url_does_not_raise(self, mock_getaddrinfo):
        # Passing means no exception propagates out of the call.
        validate_url_for_ssrf("http://example.com/api")

    def test_file_scheme_raises(self):
        self._assert_rejected("file:///etc/passwd", "not permitted")

    def test_gopher_scheme_raises(self):
        self._assert_rejected("gopher://example.com", "not permitted")

    def test_no_hostname_raises(self):
        self._assert_rejected("http://", "no hostname")

    def test_loopback_ip_raises(self):
        self._assert_rejected("http://127.0.0.1/", "non-public address")

    def test_private_class_c_raises(self):
        self._assert_rejected("http://192.168.1.1/", "non-public address")

    def test_private_class_a_raises(self):
        self._assert_rejected("http://10.0.0.1/", "non-public address")

    def test_link_local_raises(self):
        self._assert_rejected("http://169.254.1.1/", "non-public address")

    @patch("dojo.utils_ssrf.socket.getaddrinfo", side_effect=socket.gaierror("Name or service not known"))
    def test_unresolvable_hostname_raises(self, mock_getaddrinfo):
        self._assert_rejected("http://nonexistent.invalid/", "Unable to resolve")

    @patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_MIXED_ADDR_INFO)
    def test_multi_address_with_private_ip_raises(self, mock_getaddrinfo):
        # One private record among public ones must still reject the host.
        self._assert_rejected("http://example.com/", "non-public address")
64+
65+
66+
class TestMakeSsrfSafeSession(DojoTestCase):

    """Check the session object built by ``make_ssrf_safe_session``."""

    def test_returns_requests_session(self):
        built = make_ssrf_safe_session()
        self.assertIsInstance(built, requests.Session)

    def test_http_and_https_mounted_with_safe_adapter(self):
        built = make_ssrf_safe_session()
        # Both schemes must resolve to the SSRF-safe adapter.
        http_adapter = built.get_adapter("http://example.com")
        self.assertIsInstance(http_adapter, _SSRFSafeAdapter)
        https_adapter = built.get_adapter("https://example.com")
        self.assertIsInstance(https_adapter, _SSRFSafeAdapter)

unittests/tools/test_risk_recon_parser.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import datetime
2-
3-
import requests
2+
from unittest.mock import MagicMock, patch
43

54
from dojo.models import Test
5+
from dojo.tools.risk_recon.api import RiskReconAPI
66
from dojo.tools.risk_recon.parser import RiskReconParser
7+
from dojo.utils_ssrf import SSRFError
78
from unittests.dojo_test_case import DojoTestCase, get_unit_tests_scans_path
89

910

1011
class TestRiskReconAPIParser(DojoTestCase):
1112

1213
def test_api_with_bad_url(self):
1314
with (get_unit_tests_scans_path("risk_recon") / "bad_url.json").open(encoding="utf-8") as testfile, \
14-
self.assertRaises(requests.exceptions.ConnectionError):
15+
self.assertRaises(Exception): # noqa: B017 # SSRFError is caught and re-raised as Exception in api.py
1516
parser = RiskReconParser()
1617
parser.get_findings(testfile, Test())
1718

@@ -34,3 +35,20 @@ def test_parser_without_api(self):
3435
finding = findings[1]
3536
self.assertEqual(datetime.date(2017, 3, 17), finding.date.date())
3637
self.assertEqual("ff2bbdbfc2b6gsrgwergwe6b1fasfwefb", finding.unique_id_from_tool)
38+
39+
@patch("dojo.tools.risk_recon.api.validate_url_for_ssrf", side_effect=SSRFError("blocked: private address"))
40+
def test_ssrf_error_is_raised_as_exception(self, mock_validate):
41+
with self.assertRaisesRegex(Exception, "Invalid Risk Recon API url"):
42+
RiskReconAPI(api_key="somekey", endpoint="http://192.168.1.1/api", data=[])
43+
mock_validate.assert_called_once_with("http://192.168.1.1/api")
44+
45+
@patch.object(RiskReconAPI, "get_findings")
46+
@patch.object(RiskReconAPI, "map_toes")
47+
@patch("dojo.tools.risk_recon.api.make_ssrf_safe_session")
48+
@patch("dojo.tools.risk_recon.api.validate_url_for_ssrf")
49+
def test_make_ssrf_safe_session_called_on_init(self, mock_validate, mock_make_session, mock_map_toes, mock_get_findings):
50+
mock_session = MagicMock()
51+
mock_make_session.return_value = mock_session
52+
api = RiskReconAPI(api_key="somekey", endpoint="https://api.riskrecon.com/v1", data=[])
53+
mock_make_session.assert_called_once()
54+
self.assertIs(api.session, mock_session)

0 commit comments

Comments
 (0)