Skip to content

Commit 8da414a

Browse files
committed
Add URL validation and request hardening for media input loading
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent 7ee9e8b commit 8da414a

File tree

1 file changed

+123
-30
lines changed

1 file changed

+123
-30
lines changed

tensorrt_llm/inputs/utils.py

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import base64
3+
import ipaddress
34
import math
5+
import socket
46
import tempfile
57
from collections import defaultdict
68
from dataclasses import dataclass
@@ -95,6 +97,88 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image:
9597
return image.convert(to_mode)
9698

9799

100+
# Maximum allowed response size for remote fetches (200 MB).
101+
_MAX_RESPONSE_BYTES = 200 * 1024 * 1024
102+
103+
# Maximum number of redirects allowed for remote fetches.
104+
_MAX_REDIRECTS = 5
105+
106+
107+
def _validate_url(url: str) -> None:
108+
"""Validate that *url* points to a public, non-internal HTTP(S) resource.
109+
110+
Raises ``ValueError`` for URLs that target private, loopback, or
111+
link-local addresses, or that use a scheme other than http / https.
112+
"""
113+
parsed = urlparse(url)
114+
if parsed.scheme not in ("http", "https"):
115+
raise ValueError(
116+
f"Only http and https URLs are allowed, got: {parsed.scheme!r}")
117+
118+
hostname = parsed.hostname
119+
if not hostname:
120+
raise ValueError("URL has no hostname")
121+
122+
# Resolve to IP and check address range.
123+
try:
124+
infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
125+
except socket.gaierror as exc:
126+
raise ValueError(f"Could not resolve hostname {hostname!r}") from exc
127+
128+
for _family, _type, _proto, _canon, sockaddr in infos:
129+
ip = ipaddress.ip_address(sockaddr[0])
130+
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
131+
raise ValueError(f"URL resolves to a non-public address ({ip})")
132+
133+
134+
def _safe_request_get(url: str,
135+
*,
136+
stream: bool = False,
137+
timeout: int = 30) -> "requests.Response":
138+
"""``requests.get`` wrapper that validates the URL first."""
139+
_validate_url(url)
140+
resp = requests.get(
141+
url,
142+
stream=stream,
143+
timeout=timeout,
144+
allow_redirects=False,
145+
)
146+
for _ in range(_MAX_REDIRECTS):
147+
if resp.status_code not in (301, 302, 303, 307, 308):
148+
break
149+
redirect_url = resp.headers.get("Location", "")
150+
_validate_url(redirect_url)
151+
resp = requests.get(
152+
redirect_url,
153+
stream=stream,
154+
timeout=timeout,
155+
allow_redirects=False,
156+
)
157+
else:
158+
raise ValueError("Too many redirects")
159+
resp.raise_for_status()
160+
if not stream and len(resp.content) > _MAX_RESPONSE_BYTES:
161+
raise ValueError("Response exceeds maximum allowed size")
162+
return resp
163+
164+
165+
async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
166+
"""``aiohttp`` GET wrapper that validates URLs before and after redirects."""
167+
_validate_url(url)
168+
timeout = aiohttp.ClientTimeout(total=timeout_sec)
169+
async with aiohttp.ClientSession(timeout=timeout) as session:
170+
async with session.get(url,
171+
max_redirects=_MAX_REDIRECTS,
172+
allow_redirects=True) as response:
173+
# Validate the final (possibly redirected) URL.
174+
_validate_url(str(response.url))
175+
response.raise_for_status()
176+
data = await response.content.read(_MAX_RESPONSE_BYTES + 1)
177+
if len(data) > _MAX_RESPONSE_BYTES:
178+
raise ValueError("Response exceeds maximum allowed size")
179+
return data
180+
181+
98182
def _load_and_convert_image(image):
99183
image = Image.open(image)
100184
image.load()
@@ -134,12 +218,14 @@ def load_image(image: Union[str, Image.Image],
134218
parsed_url = urlparse(image)
135219

136220
if parsed_url.scheme in ["http", "https"]:
137-
image = requests.get(image, stream=True, timeout=10).raw
138-
image = _load_and_convert_image(image)
221+
resp = _safe_request_get(image, stream=True)
222+
image = _load_and_convert_image(resp.raw)
139223
elif parsed_url.scheme == "data":
140224
image = load_base64_image(parsed_url)
141-
else:
225+
elif parsed_url.scheme in ("", "file"):
142226
image = _load_and_convert_image(image)
227+
else:
228+
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
143229

144230
if format == "pt":
145231
return ToTensor()(image).to(device=device)
@@ -159,14 +245,14 @@ async def async_load_image(
159245
parsed_url = urlparse(image)
160246

161247
if parsed_url.scheme in ["http", "https"]:
162-
async with aiohttp.ClientSession() as session:
163-
async with session.get(image) as response:
164-
content = await response.read()
165-
image = _load_and_convert_image(BytesIO(content))
248+
content = await _safe_aiohttp_get(image)
249+
image = _load_and_convert_image(BytesIO(content))
166250
elif parsed_url.scheme == "data":
167251
image = load_base64_image(parsed_url)
168-
else:
252+
elif parsed_url.scheme in ("", "file"):
169253
image = _load_and_convert_image(Path(parsed_url.path))
254+
else:
255+
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
170256

171257
if format == "pt":
172258
return ToTensor()(image).to(device=device)
@@ -271,7 +357,10 @@ def load_video(video: str,
271357
device: str = "cpu") -> VideoData:
272358
parsed_url = urlparse(video)
273359
results = None
274-
if parsed_url.scheme in ["http", "https", ""]:
360+
if parsed_url.scheme in ["http", "https"]:
361+
_validate_url(video)
362+
results = _load_video_by_cv2(video, num_frames, fps, format, device)
363+
elif parsed_url.scheme in ("", "file"):
275364
results = _load_video_by_cv2(video, num_frames, fps, format, device)
276365
elif parsed_url.scheme == "data":
277366
decoded_video = load_base64_video(video)
@@ -298,14 +387,12 @@ async def async_load_video(video: str,
298387
parsed_url = urlparse(video)
299388

300389
if parsed_url.scheme in ["http", "https"]:
301-
async with aiohttp.ClientSession() as session:
302-
async with session.get(video) as response:
303-
with tempfile.NamedTemporaryFile(delete=True,
304-
suffix='.mp4') as tmp:
305-
tmp.write(await response.content.read())
306-
tmp.flush()
307-
results = _load_video_by_cv2(tmp.name, num_frames, fps,
308-
format, device)
390+
video_data = await _safe_aiohttp_get(video)
391+
with tempfile.NamedTemporaryFile(delete=True, suffix='.mp4') as tmp:
392+
tmp.write(video_data)
393+
tmp.flush()
394+
results = _load_video_by_cv2(tmp.name, num_frames, fps, format,
395+
device)
309396
elif parsed_url.scheme == "data":
310397
decoded_video = load_base64_video(video)
311398
# TODO: any ways to read videos from memory, instead of writing to a tempfile?
@@ -315,8 +402,10 @@ async def async_load_video(video: str,
315402
tmp_file.flush()
316403
results = _load_video_by_cv2(tmp_file.name, num_frames, fps, format,
317404
device)
318-
else:
405+
elif parsed_url.scheme in ("", "file"):
319406
results = _load_video_by_cv2(video, num_frames, fps, format, device)
407+
else:
408+
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
320409
return results
321410

322411

@@ -335,10 +424,13 @@ def load_audio(
335424
) -> Tuple[np.ndarray, int]:
336425
parsed_url = urlparse(audio)
337426
if parsed_url.scheme in ["http", "https"]:
338-
audio = requests.get(audio, stream=True, timeout=10)
339-
audio = BytesIO(audio.content)
340-
elif parsed_url.scheme == "file":
341-
audio = _normalize_file_uri(audio)
427+
resp = _safe_request_get(audio, stream=False)
428+
audio = BytesIO(resp.content)
429+
elif parsed_url.scheme in ("", "file"):
430+
audio = _normalize_file_uri(
431+
audio) if parsed_url.scheme == "file" else audio
432+
else:
433+
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
342434

343435
audio = soundfile.read(audio)
344436
return audio
@@ -351,11 +443,13 @@ async def async_load_audio(
351443
) -> Tuple[np.ndarray, int]:
352444
parsed_url = urlparse(audio)
353445
if parsed_url.scheme in ["http", "https"]:
354-
async with aiohttp.ClientSession() as session:
355-
async with session.get(audio) as response:
356-
audio = BytesIO(await response.content.read())
357-
elif parsed_url.scheme == "file":
358-
audio = _normalize_file_uri(audio)
446+
audio_data = await _safe_aiohttp_get(audio)
447+
audio = BytesIO(audio_data)
448+
elif parsed_url.scheme in ("", "file"):
449+
audio = _normalize_file_uri(
450+
audio) if parsed_url.scheme == "file" else audio
451+
else:
452+
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
359453

360454
audio = soundfile.read(audio)
361455
return audio
@@ -364,9 +458,8 @@ async def async_load_audio(
364458
def encode_base64_content_from_url(content_url: str) -> str:
365459
"""Encode a content retrieved from a remote url to base64 format."""
366460

367-
with requests.get(content_url, timeout=10) as response:
368-
response.raise_for_status()
369-
result = base64.b64encode(response.content).decode('utf-8')
461+
resp = _safe_request_get(content_url, stream=False)
462+
result = base64.b64encode(resp.content).decode('utf-8')
370463

371464
return result
372465

0 commit comments

Comments
 (0)