From 92cfc164bf66f3861a8563a335e41e72f7a66098 Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Fri, 3 Apr 2026 20:59:58 +0000 Subject: [PATCH 1/6] Add URL validation and request hardening for media input loading Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 162 +++++++++++++++++++++++++++-------- 1 file changed, 124 insertions(+), 38 deletions(-) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 355e4e0f40ae..9fcc6ff93ca3 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -1,7 +1,9 @@ import asyncio import base64 +import ipaddress import math import os +import socket import tempfile from collections import defaultdict from dataclasses import dataclass @@ -111,6 +113,88 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: return image.convert(to_mode) +# Maximum allowed response size for remote fetches (200 MB). +_MAX_RESPONSE_BYTES = 200 * 1024 * 1024 + +# Maximum number of redirects allowed for remote fetches. +_MAX_REDIRECTS = 5 + + +def _validate_url(url: str) -> None: + """Validate that *url* points to a public, non-internal HTTP(S) resource. + + Raises ``ValueError`` for URLs that target private, loopback, or + link-local addresses, or that use a scheme other than http / https. + """ + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise ValueError( + f"Only http and https URLs are allowed, got: {parsed.scheme!r}") + + hostname = parsed.hostname + if not hostname: + raise ValueError("URL has no hostname") + + # Resolve to IP and check address range. + try: + infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP) + except socket.gaierror as exc: + raise ValueError(f"Could not resolve hostname {hostname!r}") from exc + + for _family, _type, _proto, _canon, sockaddr in infos: + ip = ipaddress.ip_address(sockaddr[0]) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + raise ValueError(f"URL resolves to a non-public address ({ip})") + + +def _safe_request_get(url: str, + *, + stream: bool = False, + timeout: int = 30) -> "requests.Response": + """``requests.get`` wrapper that validates the URL first.""" + _validate_url(url) + resp = requests.get( + url, + stream=stream, + timeout=timeout, + allow_redirects=False, + ) + for _ in range(_MAX_REDIRECTS): + if resp.status_code not in (301, 302, 303, 307, 308): + break + redirect_url = resp.headers.get("Location", "") + _validate_url(redirect_url) + resp = requests.get( + redirect_url, + stream=stream, + timeout=timeout, + allow_redirects=False, + ) + else: + raise ValueError("Too many redirects") + resp.raise_for_status() + if not stream and len(resp.content) > _MAX_RESPONSE_BYTES: + raise ValueError("Response exceeds maximum allowed size") + return resp + + +async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes: + """``aiohttp`` GET wrapper that validates URLs before and after redirects.""" + _validate_url(url) + timeout = aiohttp.ClientTimeout(total=timeout_sec) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, + max_redirects=_MAX_REDIRECTS, + allow_redirects=True) as response: + # Validate the final (possibly redirected) URL. + _validate_url(str(response.url)) + response.raise_for_status() + data = await response.content.read(_MAX_RESPONSE_BYTES + 1) + if len(data) > _MAX_RESPONSE_BYTES: + raise ValueError("Response exceeds maximum allowed size") + return data + + def _load_and_convert_image(image): image = Image.open(image) image.load() @@ -150,12 +234,14 @@ def load_image(image: Union[str, Image.Image], parsed_url = urlparse(image) if parsed_url.scheme in ["http", "https"]: - image = requests.get(image, stream=True, timeout=10).raw - image = _load_and_convert_image(image) + resp = _safe_request_get(image, stream=True) + image = _load_and_convert_image(resp.raw) elif parsed_url.scheme == "data": image = load_base64_image(parsed_url) - else: + elif parsed_url.scheme in ("", "file"): image = _load_and_convert_image(image) + else: + raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") if format == "pt": return ToTensor()(image).to(device=device) @@ -175,16 +261,14 @@ async def async_load_image( parsed_url = urlparse(image) if parsed_url.scheme in ["http", "https"]: - session = await _get_aiohttp_session() - async with session.get(image) as response: - content = await response.read() - image = await asyncio.to_thread(_load_and_convert_image, - BytesIO(content)) + content = await _safe_aiohttp_get(image) + image = _load_and_convert_image(BytesIO(content)) elif parsed_url.scheme == "data": - image = await asyncio.to_thread(load_base64_image, parsed_url) + image = load_base64_image(parsed_url) + elif parsed_url.scheme in ("", "file"): + image = _load_and_convert_image(Path(parsed_url.path)) else: - image = await asyncio.to_thread(_load_and_convert_image, - Path(parsed_url.path)) + raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") if format == "pt": return await asyncio.to_thread(lambda: ToTensor() @@ -394,12 +478,12 @@ def load_video(video: str, device: str = "cpu", extract_audio: bool = False) -> VideoData: parsed_url = urlparse(video) - if parsed_url.scheme in ["http", "https", ""]: - return _load_video_by_cv2(video, - num_frames, - fps, - format, - device, + if parsed_url.scheme in ["http", "https"]: + _validate_url(video) + return _load_video_by_cv2(video, num_frames, fps, format, device, + extract_audio=extract_audio) + elif parsed_url.scheme in ("", "file"): + return _load_video_by_cv2(video, num_frames, fps, format, device, extract_audio=extract_audio) elif parsed_url.scheme == "data": decoded_video = load_base64_video(video) @@ -439,16 +523,17 @@ def _load_from_bytes(data: bytes) -> VideoData: extract_audio=extract_audio) if parsed_url.scheme in ["http", "https"]: - session = await _get_aiohttp_session() - async with session.get(video) as response: - content = await response.content.read() - return await asyncio.to_thread(_load_from_bytes, content) + video_data = await _safe_aiohttp_get(video) + return _load_from_bytes(video_data) elif parsed_url.scheme == "data": decoded_video = load_base64_video(video) - return await asyncio.to_thread(_load_from_bytes, decoded_video) - else: + return _load_from_bytes(decoded_video) + elif parsed_url.scheme in ("", "file"): return await asyncio.to_thread(_load_video_by_cv2, video, num_frames, - fps, format, device) + fps, format, device, + extract_audio=extract_audio) + else: + raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") def _normalize_file_uri(uri: str) -> str: @@ -466,10 +551,13 @@ def load_audio( ) -> Tuple[np.ndarray, int]: parsed_url = urlparse(audio) if parsed_url.scheme in ["http", "https"]: - audio = requests.get(audio, stream=True, timeout=10) - audio = BytesIO(audio.content) - elif parsed_url.scheme == "file": - audio = _normalize_file_uri(audio) + resp = _safe_request_get(audio, stream=False) + audio = BytesIO(resp.content) + elif parsed_url.scheme in ("", "file"): + audio = _normalize_file_uri( + audio) if parsed_url.scheme == "file" else audio + else: + raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") audio = soundfile.read(audio) return audio @@ -488,13 +576,12 @@ async def async_load_audio( parsed_url = urlparse(audio) if parsed_url.scheme in ["http", "https"]: - session = await _get_aiohttp_session() - async with session.get(audio) as response: - content = await response.content.read() - # Offload CPU-bound soundfile decoding to thread pool - return await asyncio.to_thread(soundfile.read, BytesIO(content)) - elif parsed_url.scheme == "file": - audio = _normalize_file_uri(audio) + audio_data = await _safe_aiohttp_get(audio) + audio = BytesIO(audio_data) + elif parsed_url.scheme in ("", "file"): + audio = _normalize_file_uri(audio) if parsed_url.scheme == "file" else audio + else: + raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") return await asyncio.to_thread(soundfile.read, audio) @@ -502,9 +589,8 @@ async def async_load_audio( def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" - with requests.get(content_url, timeout=10) as response: - response.raise_for_status() - result = base64.b64encode(response.content).decode('utf-8') + resp = _safe_request_get(content_url, stream=False) + result = base64.b64encode(resp.content).decode('utf-8') return result From 8c3a795d87a69b26a73fcec15a19bf8d5b717a15 Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:03:59 +0000 Subject: [PATCH 2/6] fix(inputs): validate each redirect hop and eliminate TOCTOU in video loading Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 9fcc6ff93ca3..f68ffd89289f 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -179,20 +179,25 @@ def _safe_request_get(url: str, async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes: - """``aiohttp`` GET wrapper that validates URLs before and after redirects.""" + """aiohttp GET wrapper that validates every redirect hop before following.""" _validate_url(url) timeout = aiohttp.ClientTimeout(total=timeout_sec) + current_url = url async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url, - max_redirects=_MAX_REDIRECTS, - allow_redirects=True) as response: - # Validate the final (possibly redirected) URL. - _validate_url(str(response.url)) - response.raise_for_status() - data = await response.content.read(_MAX_RESPONSE_BYTES + 1) - if len(data) > _MAX_RESPONSE_BYTES: - raise ValueError("Response exceeds maximum allowed size") - return data + for _ in range(_MAX_REDIRECTS + 1): + async with session.get(current_url, + allow_redirects=False) as response: + if response.status in (301, 302, 303, 307, 308): + redirect_url = response.headers.get("Location", "") + _validate_url(redirect_url) + current_url = redirect_url + continue + response.raise_for_status() + data = await response.content.read(_MAX_RESPONSE_BYTES + 1) + if len(data) > _MAX_RESPONSE_BYTES: + raise ValueError("Response exceeds maximum allowed size") + return data + raise ValueError("Too many redirects") def _load_and_convert_image(image): @@ -479,9 +484,12 @@ def load_video(video: str, extract_audio: bool = False) -> VideoData: parsed_url = urlparse(video) if parsed_url.scheme in ["http", "https"]: - _validate_url(video) - return _load_video_by_cv2(video, num_frames, fps, format, device, - extract_audio=extract_audio) + resp = _safe_request_get(video, stream=False) + with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as tmp_file: + tmp_file.write(resp.content) + tmp_file.flush() + return _load_video_by_cv2(tmp_file.name, num_frames, fps, format, + device, extract_audio=extract_audio) elif parsed_url.scheme in ("", "file"): return _load_video_by_cv2(video, num_frames, fps, format, device, extract_audio=extract_audio) From 0791e2c57e9e8ef63a8007895385276dbd59cfb3 Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:14:48 +0000 Subject: [PATCH 3/6] fix(inputs): close intermediate response before redirect in _safe_request_get Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index f68ffd89289f..410ee5cd9c97 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -164,6 +164,7 @@ def _safe_request_get(url: str, break redirect_url = resp.headers.get("Location", "") _validate_url(redirect_url) + resp.close() resp = requests.get( redirect_url, stream=stream, From 821f56c805a0bf3fc515cc9b35722bd10a28096a Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Fri, 10 Apr 2026 10:50:43 -0700 Subject: [PATCH 4/6] fix(inputs): use RuntimeError in URL validation helpers and add tests Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 24 +- tests/unittest/inputs/test_url_validation.py | 266 +++++++++++++++++++ 2 files changed, 281 insertions(+), 9 deletions(-) create mode 100644 tests/unittest/inputs/test_url_validation.py diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 410ee5cd9c97..04e7e26cb878 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -123,28 +123,34 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: def _validate_url(url: str) -> None: """Validate that *url* points to a public, non-internal HTTP(S) resource. - Raises ``ValueError`` for URLs that target private, loopback, or + Raises ``RuntimeError`` for URLs that target private, loopback, or link-local addresses, or that use a scheme other than http / https. + + Note: validation is performed at DNS-resolution time. A DNS-rebinding + attack (TTL=0, resolves to a public IP during validation then a private IP + during the actual TCP connect) could bypass this check. For strict + isolation, supplement with network-level egress filtering that blocks + RFC-1918 and APIPA ranges at the host firewall. """ parsed = urlparse(url) if parsed.scheme not in ("http", "https"): - raise ValueError( + raise RuntimeError( f"Only http and https URLs are allowed, got: {parsed.scheme!r}") hostname = parsed.hostname if not hostname: - raise ValueError("URL has no hostname") + raise RuntimeError("URL has no hostname") # Resolve to IP and check address range. try: infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP) except socket.gaierror as exc: - raise ValueError(f"Could not resolve hostname {hostname!r}") from exc + raise RuntimeError(f"Could not resolve hostname {hostname!r}") from exc for _family, _type, _proto, _canon, sockaddr in infos: ip = ipaddress.ip_address(sockaddr[0]) if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - raise ValueError(f"URL resolves to a non-public address ({ip})") + raise RuntimeError(f"URL resolves to a non-public address ({ip})") def _safe_request_get(url: str, @@ -172,10 +178,10 @@ def _safe_request_get(url: str, allow_redirects=False, ) else: - raise ValueError("Too many redirects") + raise RuntimeError("Too many redirects") resp.raise_for_status() if not stream and len(resp.content) > _MAX_RESPONSE_BYTES: - raise ValueError("Response exceeds maximum allowed size") + raise RuntimeError("Response exceeds maximum allowed size") return resp @@ -196,9 +202,9 @@ async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes: response.raise_for_status() data = await response.content.read(_MAX_RESPONSE_BYTES + 1) if len(data) > _MAX_RESPONSE_BYTES: - raise ValueError("Response exceeds maximum allowed size") + raise RuntimeError("Response exceeds maximum allowed size") return data - raise ValueError("Too many redirects") + raise RuntimeError("Too many redirects") def _load_and_convert_image(image): diff --git a/tests/unittest/inputs/test_url_validation.py b/tests/unittest/inputs/test_url_validation.py new file mode 100644 index 000000000000..65c1f7ff07f9 --- /dev/null +++ b/tests/unittest/inputs/test_url_validation.py @@ -0,0 +1,266 @@ +"""Unit tests for SSRF-prevention URL validation helpers in inputs/utils.py. + +Tests cover _validate_url(), _safe_request_get(), and _safe_aiohttp_get() +without making real network connections. +""" + +import asyncio +import socket +from unittest.mock import MagicMock, patch + +import pytest + +from tensorrt_llm.inputs.utils import ( + _MAX_REDIRECTS, + _MAX_RESPONSE_BYTES, + _safe_aiohttp_get, + _safe_request_get, + _validate_url, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _dns(ip: str): + """Return a minimal getaddrinfo result that resolves to *ip*.""" + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0))] + + +PUBLIC_DNS = _dns("93.184.216.34") # example.com + + +# --------------------------------------------------------------------------- +# _validate_url +# --------------------------------------------------------------------------- + + +class TestValidateUrl: + def test_rejects_file_scheme(self): + with pytest.raises(RuntimeError, match="Only http"): + _validate_url("file:///etc/passwd") + + def test_rejects_ftp_scheme(self): + with pytest.raises(RuntimeError, match="Only http"): + _validate_url("ftp://example.com/file") + + def test_rejects_missing_hostname(self): + with pytest.raises(RuntimeError, match="no hostname"): + _validate_url("http:///path") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("127.0.0.1")) + def test_rejects_loopback_ipv4(self, _): + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://localhost/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("::1")) + def test_rejects_loopback_ipv6(self, _): + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://ip6-localhost/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("10.0.0.1")) + def test_rejects_rfc1918_10(self, _): + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://internal.corp/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("172.16.0.1")) + def test_rejects_rfc1918_172(self, _): + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://vpn.example.com/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("192.168.1.100")) + def test_rejects_rfc1918_192(self, _): + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://router.local/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254")) + def test_rejects_cloud_metadata_imds(self, _): + """AWS / Azure / GCP instance metadata service must be blocked.""" + with pytest.raises(RuntimeError, match="non-public"): + _validate_url("http://169.254.169.254/latest/meta-data/") + + @patch( + "tensorrt_llm.inputs.utils.socket.getaddrinfo", + side_effect=socket.gaierror("Name or service not known"), + ) + def test_rejects_unresolvable_hostname(self, _): + with pytest.raises(RuntimeError, match="Could not resolve"): + _validate_url("http://this.does.not.exist.invalid/") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + def test_accepts_public_hostname(self, _): + _validate_url("http://example.com/image.jpg") # must not raise + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + def test_accepts_https(self, _): + _validate_url("https://example.com/image.jpg") # must not raise + + +# --------------------------------------------------------------------------- +# _safe_request_get +# --------------------------------------------------------------------------- + + +class TestSafeRequestGet: + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + @patch("tensorrt_llm.inputs.utils.requests.get") + def test_returns_response_on_success(self, mock_get, _): + resp = MagicMock() + resp.status_code = 200 + resp.content = b"fake-image" + mock_get.return_value = resp + result = _safe_request_get("http://example.com/image.jpg") + assert result.status_code == 200 + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + @patch("tensorrt_llm.inputs.utils.requests.get") + def test_follows_valid_redirect(self, mock_get, _): + redirect = MagicMock() + redirect.status_code = 302 + redirect.headers = {"Location": "http://example.com/final.jpg"} + final = MagicMock() + final.status_code = 200 + final.content = b"image-data" + mock_get.side_effect = [redirect, final] + result = _safe_request_get("http://example.com/image.jpg") + assert result.status_code == 200 + assert mock_get.call_count == 2 + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo") + @patch("tensorrt_llm.inputs.utils.requests.get") + def test_raises_on_redirect_to_private_ip(self, mock_get, mock_dns): + def dns_side_effect(host, *args, **kwargs): + return _dns("192.168.1.1") if "evil" in str(host) else PUBLIC_DNS + + mock_dns.side_effect = dns_side_effect + + redirect = MagicMock() + redirect.status_code = 301 + redirect.headers = {"Location": "http://evil.internal/secret"} + mock_get.return_value = redirect + + with pytest.raises(RuntimeError, match="non-public"): + _safe_request_get("http://example.com/image.jpg") + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + @patch("tensorrt_llm.inputs.utils.requests.get") + def test_raises_after_too_many_redirects(self, mock_get, _): + redirect = MagicMock() + redirect.status_code = 302 + redirect.headers = {"Location": "http://example.com/image.jpg"} + mock_get.return_value = redirect # always redirect + + with pytest.raises(RuntimeError, match="Too many redirects"): + _safe_request_get("http://example.com/image.jpg") + + # Initial request + _MAX_REDIRECTS follow-ups before giving up + assert mock_get.call_count == _MAX_REDIRECTS + 1 + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + @patch("tensorrt_llm.inputs.utils.requests.get") + def test_raises_on_oversized_response(self, mock_get, _): + resp = MagicMock() + resp.status_code = 200 + resp.content = b"x" * (_MAX_RESPONSE_BYTES + 1) + mock_get.return_value = resp + + with pytest.raises(RuntimeError, match="maximum allowed size"): + _safe_request_get("http://example.com/huge.bin", stream=False) + + def test_rejects_private_url_before_request(self): + """_validate_url() must fire before any requests.get() call.""" + with ( + patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("10.0.0.1")), + patch("tensorrt_llm.inputs.utils.requests.get") as mock_get, + ): + with pytest.raises(RuntimeError, match="non-public"): + _safe_request_get("http://internal.corp/image.jpg") + mock_get.assert_not_called() + + +# --------------------------------------------------------------------------- +# _safe_aiohttp_get +# --------------------------------------------------------------------------- + + +class TestSafeAiohttpGet: + def _run(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + def test_raises_after_too_many_redirects(self, _): + class _FakeResponse: + status = 302 + headers = {"Location": "http://example.com/image.jpg"} + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + pass + + async def raise_for_status(self): + pass + + class _FakeSession: + def get(self, url, **kwargs): + return _FakeResponse() + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + pass + + with patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession", return_value=_FakeSession()): + with pytest.raises(RuntimeError, match="Too many redirects"): + self._run(_safe_aiohttp_get("http://example.com/image.jpg")) + + @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS) + def test_raises_on_oversized_response(self, _): + oversized = b"x" * (_MAX_RESPONSE_BYTES + 1) + + class _FakeResponse: + status = 200 + headers = {} + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + pass + + async def raise_for_status(self): + pass + + class content: + @staticmethod + async def read(n): + return oversized[:n] + + class _FakeSession: + def get(self, url, **kwargs): + return _FakeResponse() + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + pass + + with patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession", return_value=_FakeSession()): + with pytest.raises(RuntimeError, match="maximum allowed size"): + self._run(_safe_aiohttp_get("http://example.com/huge.bin")) + + def test_rejects_private_url_before_request(self): + """_validate_url() must fire before any aiohttp call.""" + with ( + patch( + "tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254") + ), + patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession") as mock_session, + ): + with pytest.raises(RuntimeError, match="non-public"): + self._run(_safe_aiohttp_get("http://169.254.169.254/latest/")) + mock_session.assert_not_called() From 7c1c17fda820ab8e430080c27f312d2b105a9c03 Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Sat, 11 Apr 2026 01:09:52 +0000 Subject: [PATCH 5/6] pre-commit Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 04e7e26cb878..47374433a863 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -186,7 +186,7 @@ def _safe_request_get(url: str, async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes: - """aiohttp GET wrapper that validates every redirect hop before following.""" + """Aiohttp GET wrapper that validates every redirect hop before following.""" _validate_url(url) timeout = aiohttp.ClientTimeout(total=timeout_sec) current_url = url From 21b0fd16909b4c04181aeec260e0542466766ab0 Mon Sep 17 00:00:00 2001 From: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> Date: Tue, 21 Apr 2026 00:56:50 +0000 Subject: [PATCH 6/6] pre-commit Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/utils.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 47374433a863..a4ada329ec14 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -492,13 +492,22 @@ def load_video(video: str, parsed_url = urlparse(video) if parsed_url.scheme in ["http", "https"]: resp = _safe_request_get(video, stream=False) - with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as tmp_file: + with tempfile.NamedTemporaryFile(delete=True, + suffix=".mp4") as tmp_file: tmp_file.write(resp.content) tmp_file.flush() - return _load_video_by_cv2(tmp_file.name, num_frames, fps, format, - device, extract_audio=extract_audio) + return _load_video_by_cv2(tmp_file.name, + num_frames, + fps, + format, + device, + extract_audio=extract_audio) elif parsed_url.scheme in ("", "file"): - return _load_video_by_cv2(video, num_frames, fps, format, device, + return _load_video_by_cv2(video, + num_frames, + fps, + format, + device, extract_audio=extract_audio) elif parsed_url.scheme == "data": decoded_video = load_base64_video(video) @@ -544,8 +553,12 @@ def _load_from_bytes(data: bytes) -> VideoData: decoded_video = load_base64_video(video) return _load_from_bytes(decoded_video) elif parsed_url.scheme in ("", "file"): - return await asyncio.to_thread(_load_video_by_cv2, video, num_frames, - fps, format, device, + return await asyncio.to_thread(_load_video_by_cv2, + video, + num_frames, + fps, + format, + device, extract_audio=extract_audio) else: raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") @@ -594,7 +607,8 @@ async def async_load_audio( audio_data = await _safe_aiohttp_get(audio) audio = BytesIO(audio_data) elif parsed_url.scheme in ("", "file"): - audio = _normalize_file_uri(audio) if parsed_url.scheme == "file" else audio + audio = _normalize_file_uri( + audio) if parsed_url.scheme == "file" else audio else: raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")