Skip to content

Commit 25ca16f

Browse files
committed
security: fix SSRF in load_image and load_video URL fetching
1 parent de5fcf6 commit 25ca16f

1 file changed

Lines changed: 46 additions & 0 deletions

File tree

src/diffusers/utils/loading_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import ipaddress
12
import os
3+
import socket
24
import tempfile
35
from typing import Any, Callable
46
from urllib.parse import unquote, urlparse
@@ -11,6 +13,48 @@
1113
from .import_utils import BACKENDS_MAPPING, is_imageio_available
1214

1315

16+
_BLOCKED_IPV4 = [
17+
ipaddress.ip_network("127.0.0.0/8"),
18+
ipaddress.ip_network("10.0.0.0/8"),
19+
ipaddress.ip_network("172.16.0.0/12"),
20+
ipaddress.ip_network("192.168.0.0/16"),
21+
ipaddress.ip_network("169.254.0.0/16"),
22+
ipaddress.ip_network("100.64.0.0/10"),
23+
ipaddress.ip_network("0.0.0.0/8"),
24+
]
25+
_BLOCKED_IPV6 = [
26+
ipaddress.ip_network("::1/128"),
27+
ipaddress.ip_network("fc00::/7"),
28+
ipaddress.ip_network("fe80::/10"),
29+
]
30+
31+
32+
def _validate_url(url: str) -> None:
33+
"""Raise ValueError if url resolves to a private/reserved IP address (SSRF prevention)."""
34+
parsed = urlparse(url)
35+
if parsed.scheme not in ("http", "https"):
36+
raise ValueError(f"URL scheme {parsed.scheme!r} is not allowed.")
37+
if not parsed.hostname:
38+
raise ValueError(f"URL has no hostname: {url!r}")
39+
try:
40+
infos = socket.getaddrinfo(parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80))
41+
except socket.gaierror as exc:
42+
raise ValueError(f"Cannot resolve hostname {parsed.hostname!r}: {exc}") from exc
43+
for _family, _, _, _, sockaddr in infos:
44+
try:
45+
addr = ipaddress.ip_address(sockaddr[0])
46+
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped:
47+
addr = addr.ipv4_mapped
48+
nets = _BLOCKED_IPV4 if isinstance(addr, ipaddress.IPv4Address) else _BLOCKED_IPV6
49+
if any(addr in net for net in nets):
50+
raise ValueError(
51+
f"URL {url!r} resolves to private/reserved IP address {addr}. "
52+
"Requests to internal network addresses are not allowed."
53+
)
54+
except ValueError:
55+
raise
56+
57+
1458
def load_image(
1559
image: str | PIL.Image.Image, convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None
1660
) -> PIL.Image.Image:
@@ -30,6 +74,7 @@ def load_image(
3074
"""
3175
if isinstance(image, str):
3276
if image.startswith("http://") or image.startswith("https://"):
77+
_validate_url(image)
3378
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
3479
elif os.path.isfile(image):
3580
image = PIL.Image.open(image)
@@ -82,6 +127,7 @@ def load_video(
82127
)
83128

84129
if is_url:
130+
_validate_url(video)
85131
response = requests.get(video, stream=True)
86132
if response.status_code != 200:
87133
raise ValueError(f"Failed to download video. Status code: {response.status_code}")

0 commit comments

Comments
 (0)