1+ import ipaddress
12import os
3+ import socket
24import tempfile
35from typing import Any , Callable
46from urllib .parse import unquote , urlparse
1113from .import_utils import BACKENDS_MAPPING , is_imageio_available
1214
1315
16+ _BLOCKED_IPV4 = [
17+ ipaddress .ip_network ("127.0.0.0/8" ),
18+ ipaddress .ip_network ("10.0.0.0/8" ),
19+ ipaddress .ip_network ("172.16.0.0/12" ),
20+ ipaddress .ip_network ("192.168.0.0/16" ),
21+ ipaddress .ip_network ("169.254.0.0/16" ),
22+ ipaddress .ip_network ("100.64.0.0/10" ),
23+ ipaddress .ip_network ("0.0.0.0/8" ),
24+ ]
25+ _BLOCKED_IPV6 = [
26+ ipaddress .ip_network ("::1/128" ),
27+ ipaddress .ip_network ("fc00::/7" ),
28+ ipaddress .ip_network ("fe80::/10" ),
29+ ]
30+
31+
32+ def _validate_url (url : str ) -> None :
33+ """Raise ValueError if url resolves to a private/reserved IP address (SSRF prevention)."""
34+ parsed = urlparse (url )
35+ if parsed .scheme not in ("http" , "https" ):
36+ raise ValueError (f"URL scheme { parsed .scheme !r} is not allowed." )
37+ if not parsed .hostname :
38+ raise ValueError (f"URL has no hostname: { url !r} " )
39+ try :
40+ infos = socket .getaddrinfo (parsed .hostname , parsed .port or (443 if parsed .scheme == "https" else 80 ))
41+ except socket .gaierror as exc :
42+ raise ValueError (f"Cannot resolve hostname { parsed .hostname !r} : { exc } " ) from exc
43+ for _family , _ , _ , _ , sockaddr in infos :
44+ try :
45+ addr = ipaddress .ip_address (sockaddr [0 ])
46+ if isinstance (addr , ipaddress .IPv6Address ) and addr .ipv4_mapped :
47+ addr = addr .ipv4_mapped
48+ nets = _BLOCKED_IPV4 if isinstance (addr , ipaddress .IPv4Address ) else _BLOCKED_IPV6
49+ if any (addr in net for net in nets ):
50+ raise ValueError (
51+ f"URL { url !r} resolves to private/reserved IP address { addr } . "
52+ "Requests to internal network addresses are not allowed."
53+ )
54+ except ValueError :
55+ raise
56+
57+
1458def load_image (
1559 image : str | PIL .Image .Image , convert_method : Callable [[PIL .Image .Image ], PIL .Image .Image ] | None = None
1660) -> PIL .Image .Image :
@@ -30,6 +74,7 @@ def load_image(
3074 """
3175 if isinstance (image , str ):
3276 if image .startswith ("http://" ) or image .startswith ("https://" ):
77+ _validate_url (image )
3378 image = PIL .Image .open (requests .get (image , stream = True , timeout = DIFFUSERS_REQUEST_TIMEOUT ).raw )
3479 elif os .path .isfile (image ):
3580 image = PIL .Image .open (image )
@@ -82,6 +127,7 @@ def load_video(
82127 )
83128
84129 if is_url :
130+ _validate_url (video )
85131 response = requests .get (video , stream = True )
86132 if response .status_code != 200 :
87133 raise ValueError (f"Failed to download video. Status code: { response .status_code } " )
0 commit comments