Skip to content

Commit 260e051

Browse files
committed
fix(inputs): validate each redirect hop and eliminate TOCTOU in video loading
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent 8da414a commit 260e051

1 file changed

Lines changed: 23 additions & 13 deletions

File tree

tensorrt_llm/inputs/utils.py

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -163,20 +163,25 @@ def _safe_request_get(url: str,
163163

164164

165165
async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
166-
"""``aiohttp`` GET wrapper that validates URLs before and after redirects."""
166+
"""aiohttp GET wrapper that validates every redirect hop before following."""
167167
_validate_url(url)
168168
timeout = aiohttp.ClientTimeout(total=timeout_sec)
169+
current_url = url
169170
async with aiohttp.ClientSession(timeout=timeout) as session:
170-
async with session.get(url,
171-
max_redirects=_MAX_REDIRECTS,
172-
allow_redirects=True) as response:
173-
# Validate the final (possibly redirected) URL.
174-
_validate_url(str(response.url))
175-
response.raise_for_status()
176-
data = await response.content.read(_MAX_RESPONSE_BYTES + 1)
177-
if len(data) > _MAX_RESPONSE_BYTES:
178-
raise ValueError("Response exceeds maximum allowed size")
179-
return data
171+
for _ in range(_MAX_REDIRECTS + 1):
172+
async with session.get(current_url,
173+
allow_redirects=False) as response:
174+
if response.status in (301, 302, 303, 307, 308):
175+
redirect_url = response.headers.get("Location", "")
176+
_validate_url(redirect_url)
177+
current_url = redirect_url
178+
continue
179+
response.raise_for_status()
180+
data = await response.content.read(_MAX_RESPONSE_BYTES + 1)
181+
if len(data) > _MAX_RESPONSE_BYTES:
182+
raise ValueError("Response exceeds maximum allowed size")
183+
return data
184+
raise ValueError("Too many redirects")
180185

181186

182187
def _load_and_convert_image(image):
@@ -358,8 +363,13 @@ def load_video(video: str,
358363
parsed_url = urlparse(video)
359364
results = None
360365
if parsed_url.scheme in ["http", "https"]:
361-
_validate_url(video)
362-
results = _load_video_by_cv2(video, num_frames, fps, format, device)
366+
resp = _safe_request_get(video, stream=False)
367+
with tempfile.NamedTemporaryFile(delete=True,
368+
suffix=".mp4") as tmp_file:
369+
tmp_file.write(resp.content)
370+
tmp_file.flush()
371+
results = _load_video_by_cv2(tmp_file.name, num_frames, fps, format,
372+
device)
363373
elif parsed_url.scheme in ("", "file"):
364374
results = _load_video_by_cv2(video, num_frames, fps, format, device)
365375
elif parsed_url.scheme == "data":

0 commit comments

Comments
 (0)