11import asyncio
22import base64
3+ import ipaddress
34import math
5+ import socket
46import tempfile
57from collections import defaultdict
68from dataclasses import dataclass
@@ -95,6 +97,88 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image:
9597 return image .convert (to_mode )
9698
9799
# Upper bound on the size of a remotely fetched response body, in bytes
# (200 MiB). The remote-fetch helpers in this module raise ``ValueError``
# when a downloaded payload is larger than this.
_MAX_RESPONSE_BYTES = 200 * 1024 * 1024

# Upper bound on the number of HTTP redirects a remote fetch may follow
# before it is aborted.
_MAX_REDIRECTS = 5
105+
106+
107+ def _validate_url (url : str ) -> None :
108+ """Validate that *url* points to a public, non-internal HTTP(S) resource.
109+
110+ Raises ``ValueError`` for URLs that target private, loopback, or
111+ link-local addresses, or that use a scheme other than http / https.
112+ """
113+ parsed = urlparse (url )
114+ if parsed .scheme not in ("http" , "https" ):
115+ raise ValueError (
116+ f"Only http and https URLs are allowed, got: { parsed .scheme !r} " )
117+
118+ hostname = parsed .hostname
119+ if not hostname :
120+ raise ValueError ("URL has no hostname" )
121+
122+ # Resolve to IP and check address range.
123+ try :
124+ infos = socket .getaddrinfo (hostname , None , proto = socket .IPPROTO_TCP )
125+ except socket .gaierror as exc :
126+ raise ValueError (f"Could not resolve hostname { hostname !r} " ) from exc
127+
128+ for _family , _type , _proto , _canon , sockaddr in infos :
129+ ip = ipaddress .ip_address (sockaddr [0 ])
130+ if ip .is_private or ip .is_loopback or ip .is_link_local or ip .is_reserved :
131+ raise ValueError (f"URL resolves to a non-public address ({ ip } )" )
132+
133+
def _safe_request_get(url: str,
                      *,
                      stream: bool = False,
                      timeout: int = 30) -> "requests.Response":
    """``requests.get`` wrapper that validates every URL it contacts.

    Redirects are followed manually so that each hop is passed through
    ``_validate_url`` before any request is sent to it. Relative ``Location``
    headers are resolved against the URL that produced them.

    Args:
        url: Absolute http(s) URL to fetch.
        stream: When ``True`` the body is left unread so the caller can
            consume ``resp.raw``; when ``False`` the body is loaded and
            available as ``resp.content`` on return.
        timeout: Timeout in seconds handed to ``requests.get`` for each hop.

    Returns:
        The final, non-redirect response.

    Raises:
        ValueError: If any URL in the chain fails validation, a redirect has
            no ``Location`` header, more than ``_MAX_REDIRECTS`` redirects
            are needed, or the body is larger than ``_MAX_RESPONSE_BYTES``.
        requests.HTTPError: If the final response has an error status.
    """
    from urllib.parse import urljoin

    redirect_statuses = (301, 302, 303, 307, 308)
    current_url = url
    redirects_followed = 0
    while True:
        _validate_url(current_url)
        # Always stream internally so the declared size can be inspected
        # before the body is pulled into memory.
        resp = requests.get(
            current_url,
            stream=True,
            timeout=timeout,
            allow_redirects=False,
        )
        if resp.status_code not in redirect_statuses:
            break

        location = resp.headers.get("Location")
        base_url = resp.url or current_url
        # Release the connection of the intermediate hop before moving on.
        resp.close()
        if not location:
            raise ValueError("Redirect response has no Location header")
        if redirects_followed >= _MAX_REDIRECTS:
            raise ValueError("Too many redirects")
        redirects_followed += 1
        current_url = urljoin(base_url, location)

    try:
        resp.raise_for_status()
        declared = (resp.headers.get("Content-Length") or "").strip()
        if declared.isdigit() and int(declared) > _MAX_RESPONSE_BYTES:
            raise ValueError("Response exceeds maximum allowed size")
        # Bodies without a trustworthy Content-Length can only be measured
        # once downloaded; streaming callers read ``resp.raw`` themselves.
        if not stream and len(resp.content) > _MAX_RESPONSE_BYTES:
            raise ValueError("Response exceeds maximum allowed size")
    except Exception:
        resp.close()
        raise
    return resp
163+
164+
async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
    """``aiohttp`` GET wrapper that validates every URL before contacting it.

    Redirects are followed manually: each hop is checked with
    ``_validate_url`` *before* a request is sent to it, so a redirect cannot
    make this process connect to an internal address. The blocking DNS
    lookup done by the validator runs in the default executor to keep the
    event loop responsive.

    Args:
        url: Absolute http(s) URL to fetch.
        timeout_sec: Total timeout in seconds for the client session.

    Returns:
        The complete response body.

    Raises:
        ValueError: If any URL in the chain fails validation or the body is
            larger than ``_MAX_RESPONSE_BYTES``.
        aiohttp.TooManyRedirects: If more than ``_MAX_REDIRECTS`` redirects
            would have to be followed.
        aiohttp.ClientResponseError: If the final response has an error
            status.
    """
    from urllib.parse import urljoin

    redirect_statuses = (301, 302, 303, 307, 308)
    loop = asyncio.get_running_loop()
    timeout = aiohttp.ClientTimeout(total=timeout_sec)
    current_url = url
    redirects_followed = 0

    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            await loop.run_in_executor(None, _validate_url, current_url)
            async with session.get(current_url,
                                   allow_redirects=False) as response:
                location = response.headers.get("Location")
                # A 3xx without Location is treated as the final response.
                if response.status in redirect_statuses and location:
                    if redirects_followed >= _MAX_REDIRECTS:
                        raise aiohttp.TooManyRedirects(
                            response.request_info, (response, ))
                    redirects_followed += 1
                    current_url = urljoin(str(response.url), location)
                    continue

                response.raise_for_status()
                # ``StreamReader.read(n)`` may return fewer than ``n`` bytes,
                # so read chunk by chunk until EOF and stop as soon as the
                # limit is crossed.
                chunks = []
                received = 0
                async for chunk in response.content.iter_chunked(64 * 1024):
                    received += len(chunk)
                    if received > _MAX_RESPONSE_BYTES:
                        raise ValueError(
                            "Response exceeds maximum allowed size")
                    chunks.append(chunk)
                return b"".join(chunks)
180+
181+
98182def _load_and_convert_image (image ):
99183 image = Image .open (image )
100184 image .load ()
@@ -134,12 +218,14 @@ def load_image(image: Union[str, Image.Image],
134218 parsed_url = urlparse (image )
135219
136220 if parsed_url .scheme in ["http" , "https" ]:
137- image = requests . get (image , stream = True , timeout = 10 ). raw
138- image = _load_and_convert_image (image )
221+ resp = _safe_request_get (image , stream = True )
222+ image = _load_and_convert_image (resp . raw )
139223 elif parsed_url .scheme == "data" :
140224 image = load_base64_image (parsed_url )
141- else :
225+ elif parsed_url . scheme in ( "" , "file" ) :
142226 image = _load_and_convert_image (image )
227+ else :
228+ raise ValueError (f"Unsupported URL scheme: { parsed_url .scheme !r} " )
143229
144230 if format == "pt" :
145231 return ToTensor ()(image ).to (device = device )
@@ -159,14 +245,14 @@ async def async_load_image(
159245 parsed_url = urlparse (image )
160246
161247 if parsed_url .scheme in ["http" , "https" ]:
162- async with aiohttp .ClientSession () as session :
163- async with session .get (image ) as response :
164- content = await response .read ()
165- image = _load_and_convert_image (BytesIO (content ))
248+ content = await _safe_aiohttp_get (image )
249+ image = _load_and_convert_image (BytesIO (content ))
166250 elif parsed_url .scheme == "data" :
167251 image = load_base64_image (parsed_url )
168- else :
252+ elif parsed_url . scheme in ( "" , "file" ) :
169253 image = _load_and_convert_image (Path (parsed_url .path ))
254+ else :
255+ raise ValueError (f"Unsupported URL scheme: { parsed_url .scheme !r} " )
170256
171257 if format == "pt" :
172258 return ToTensor ()(image ).to (device = device )
@@ -271,7 +357,10 @@ def load_video(video: str,
271357 device : str = "cpu" ) -> VideoData :
272358 parsed_url = urlparse (video )
273359 results = None
274- if parsed_url .scheme in ["http" , "https" , "" ]:
360+ if parsed_url .scheme in ["http" , "https" ]:
361+ _validate_url (video )
362+ results = _load_video_by_cv2 (video , num_frames , fps , format , device )
363+ elif parsed_url .scheme in ("" , "file" ):
275364 results = _load_video_by_cv2 (video , num_frames , fps , format , device )
276365 elif parsed_url .scheme == "data" :
277366 decoded_video = load_base64_video (video )
@@ -298,14 +387,12 @@ async def async_load_video(video: str,
298387 parsed_url = urlparse (video )
299388
300389 if parsed_url .scheme in ["http" , "https" ]:
301- async with aiohttp .ClientSession () as session :
302- async with session .get (video ) as response :
303- with tempfile .NamedTemporaryFile (delete = True ,
304- suffix = '.mp4' ) as tmp :
305- tmp .write (await response .content .read ())
306- tmp .flush ()
307- results = _load_video_by_cv2 (tmp .name , num_frames , fps ,
308- format , device )
390+ video_data = await _safe_aiohttp_get (video )
391+ with tempfile .NamedTemporaryFile (delete = True , suffix = '.mp4' ) as tmp :
392+ tmp .write (video_data )
393+ tmp .flush ()
394+ results = _load_video_by_cv2 (tmp .name , num_frames , fps , format ,
395+ device )
309396 elif parsed_url .scheme == "data" :
310397 decoded_video = load_base64_video (video )
311398 # TODO: any ways to read videos from memory, instead of writing to a tempfile?
@@ -315,8 +402,10 @@ async def async_load_video(video: str,
315402 tmp_file .flush ()
316403 results = _load_video_by_cv2 (tmp_file .name , num_frames , fps , format ,
317404 device )
318- else :
405+ elif parsed_url . scheme in ( "" , "file" ) :
319406 results = _load_video_by_cv2 (video , num_frames , fps , format , device )
407+ else :
408+ raise ValueError (f"Unsupported URL scheme: { parsed_url .scheme !r} " )
320409 return results
321410
322411
@@ -335,10 +424,13 @@ def load_audio(
335424) -> Tuple [np .ndarray , int ]:
336425 parsed_url = urlparse (audio )
337426 if parsed_url .scheme in ["http" , "https" ]:
338- audio = requests .get (audio , stream = True , timeout = 10 )
339- audio = BytesIO (audio .content )
340- elif parsed_url .scheme == "file" :
341- audio = _normalize_file_uri (audio )
427+ resp = _safe_request_get (audio , stream = False )
428+ audio = BytesIO (resp .content )
429+ elif parsed_url .scheme in ("" , "file" ):
430+ audio = _normalize_file_uri (
431+ audio ) if parsed_url .scheme == "file" else audio
432+ else :
433+ raise ValueError (f"Unsupported URL scheme: { parsed_url .scheme !r} " )
342434
343435 audio = soundfile .read (audio )
344436 return audio
@@ -351,11 +443,13 @@ async def async_load_audio(
351443) -> Tuple [np .ndarray , int ]:
352444 parsed_url = urlparse (audio )
353445 if parsed_url .scheme in ["http" , "https" ]:
354- async with aiohttp .ClientSession () as session :
355- async with session .get (audio ) as response :
356- audio = BytesIO (await response .content .read ())
357- elif parsed_url .scheme == "file" :
358- audio = _normalize_file_uri (audio )
446+ audio_data = await _safe_aiohttp_get (audio )
447+ audio = BytesIO (audio_data )
448+ elif parsed_url .scheme in ("" , "file" ):
449+ audio = _normalize_file_uri (
450+ audio ) if parsed_url .scheme == "file" else audio
451+ else :
452+ raise ValueError (f"Unsupported URL scheme: { parsed_url .scheme !r} " )
359453
360454 audio = soundfile .read (audio )
361455 return audio
@@ -364,9 +458,8 @@ async def async_load_audio(
364458def encode_base64_content_from_url (content_url : str ) -> str :
365459 """Encode a content retrieved from a remote url to base64 format."""
366460
367- with requests .get (content_url , timeout = 10 ) as response :
368- response .raise_for_status ()
369- result = base64 .b64encode (response .content ).decode ('utf-8' )
461+ resp = _safe_request_get (content_url , stream = False )
462+ result = base64 .b64encode (resp .content ).decode ('utf-8' )
370463
371464 return result
372465
0 commit comments