diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py
index 355e4e0f40ae..a4ada329ec14 100644
--- a/tensorrt_llm/inputs/utils.py
+++ b/tensorrt_llm/inputs/utils.py
@@ -1,7 +1,9 @@
 import asyncio
 import base64
+import ipaddress
 import math
 import os
+import socket
 import tempfile
 from collections import defaultdict
 from dataclasses import dataclass
@@ -111,6 +113,141 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image:
         return image.convert(to_mode)
 
 
+# Maximum allowed response size for remote fetches (200 MB).
+_MAX_RESPONSE_BYTES = 200 * 1024 * 1024
+
+# Maximum number of redirects allowed for remote fetches.
+_MAX_REDIRECTS = 5
+
+# HTTP status codes that are followed manually as redirects.
+_REDIRECT_STATUSES = (301, 302, 303, 307, 308)
+
+
+def _validate_url(url: str) -> None:
+    """Validate that *url* points to a public, non-internal HTTP(S) resource.
+
+    Raises ``RuntimeError`` for URLs that target private, loopback,
+    link-local, reserved, multicast, or unspecified addresses, or that use a
+    scheme other than http / https.
+
+    Note: validation is performed at DNS-resolution time. A DNS-rebinding
+    attack (TTL=0, resolves to a public IP during validation then a private IP
+    during the actual TCP connect) could bypass this check. For strict
+    isolation, supplement with network-level egress filtering that blocks
+    RFC-1918 and APIPA ranges at the host firewall.
+    """
+    parsed = urlparse(url)
+    if parsed.scheme not in ("http", "https"):
+        raise RuntimeError(
+            f"Only http and https URLs are allowed, got: {parsed.scheme!r}")
+
+    hostname = parsed.hostname
+    if not hostname:
+        raise RuntimeError("URL has no hostname")
+
+    # Resolve to IP and check address range.
+    try:
+        infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
+    except socket.gaierror as exc:
+        raise RuntimeError(f"Could not resolve hostname {hostname!r}") from exc
+
+    for _family, _type, _proto, _canon, sockaddr in infos:
+        ip = ipaddress.ip_address(sockaddr[0])
+        # Classify IPv4-mapped IPv6 (::ffff:a.b.c.d) by the embedded IPv4
+        # address; not every Python version does this on its own.
+        mapped = getattr(ip, "ipv4_mapped", None)
+        if mapped is not None:
+            ip = mapped
+        if (ip.is_private or ip.is_loopback or ip.is_link_local
+                or ip.is_reserved or ip.is_multicast or ip.is_unspecified):
+            raise RuntimeError(f"URL resolves to a non-public address ({ip})")
+
+
+def _resolve_redirect(current_url: str, location: str) -> str:
+    """Return the validated absolute target of a redirect response.
+
+    ``Location`` may be a relative reference, so it is resolved against
+    *current_url* before being validated.
+    """
+    # Imported locally so this helper does not depend on the module-level
+    # ``urllib.parse`` import list.
+    from urllib.parse import urljoin
+
+    if not location:
+        raise RuntimeError("Redirect response has no Location header")
+    target = urljoin(current_url, location)
+    _validate_url(target)
+    return target
+
+
+def _safe_request_get(url: str,
+                      *,
+                      stream: bool = False,
+                      timeout: int = 30) -> "requests.Response":
+    """``requests.get`` wrapper that validates the URL and each redirect hop.
+
+    The size limit is only enforced when ``stream`` is false; streaming
+    callers must bound their own reads.
+    """
+    _validate_url(url)
+    current_url = url
+    # One initial request plus at most _MAX_REDIRECTS followed redirects.
+    for _ in range(_MAX_REDIRECTS + 1):
+        resp = requests.get(
+            current_url,
+            stream=stream,
+            timeout=timeout,
+            allow_redirects=False,
+        )
+        if resp.status_code not in _REDIRECT_STATUSES:
+            break
+        try:
+            current_url = _resolve_redirect(current_url,
+                                            resp.headers.get("Location", ""))
+        finally:
+            resp.close()
+    else:
+        raise RuntimeError("Too many redirects")
+    try:
+        resp.raise_for_status()
+        if not stream and len(resp.content) > _MAX_RESPONSE_BYTES:
+            raise RuntimeError("Response exceeds maximum allowed size")
+    except Exception:
+        # Do not leak the connection when the response is rejected.
+        resp.close()
+        raise
+    return resp
+
+
+async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
+    """Aiohttp GET wrapper that validates every redirect hop before following."""
+    # DNS resolution blocks, so run validation off the event loop.
+    await asyncio.to_thread(_validate_url, url)
+    timeout = aiohttp.ClientTimeout(total=timeout_sec)
+    current_url = url
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        for _ in range(_MAX_REDIRECTS + 1):
+            async with session.get(current_url,
+                                   allow_redirects=False) as response:
+                if response.status in _REDIRECT_STATUSES:
+                    current_url = await asyncio.to_thread(
+                        _resolve_redirect, current_url,
+                        response.headers.get("Location", ""))
+                    continue
+                response.raise_for_status()
+                # read(n) returns as soon as any data is buffered, so keep
+                # reading until EOF or until the cap is exceeded.
+                body = bytearray()
+                while len(body) <= _MAX_RESPONSE_BYTES:
+                    chunk = await response.content.read(
+                        _MAX_RESPONSE_BYTES + 1 - len(body))
+                    if not chunk:
+                        return bytes(body)
+                    body += chunk
+                raise RuntimeError("Response exceeds maximum allowed size")
+    raise RuntimeError("Too many redirects")
+
+
 def _load_and_convert_image(image):
     image = Image.open(image)
     image.load()
@@ -150,12 +287,14 @@ def load_image(image: Union[str, Image.Image],
     parsed_url = urlparse(image)
 
     if parsed_url.scheme in ["http", "https"]:
-        image = requests.get(image, stream=True, timeout=10).raw
-        image = _load_and_convert_image(image)
+        resp = _safe_request_get(image, stream=False)
+        image = _load_and_convert_image(BytesIO(resp.content))
     elif parsed_url.scheme == "data":
         image = load_base64_image(parsed_url)
-    else:
+    elif parsed_url.scheme in ("", "file"):
         image = _load_and_convert_image(image)
+    else:
+        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
 
     if format == "pt":
         return ToTensor()(image).to(device=device)
@@ -175,16 +314,16 @@ async def async_load_image(
     parsed_url = urlparse(image)
 
     if parsed_url.scheme in ["http", "https"]:
-        session = await _get_aiohttp_session()
-        async with session.get(image) as response:
-            content = await response.read()
-            image = await asyncio.to_thread(_load_and_convert_image,
-                                            BytesIO(content))
+        content = await _safe_aiohttp_get(image)
+        image = await asyncio.to_thread(_load_and_convert_image,
+                                        BytesIO(content))
     elif parsed_url.scheme == "data":
         image = await asyncio.to_thread(load_base64_image, parsed_url)
-    else:
+    elif parsed_url.scheme in ("", "file"):
         image = await asyncio.to_thread(_load_and_convert_image,
                                         Path(parsed_url.path))
+    else:
+        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
 
     if format == "pt":
         return await asyncio.to_thread(lambda: ToTensor()
@@ -394,7 +533,19 @@ def load_video(video: str,
                device: str = "cpu",
                extract_audio: bool = False) -> VideoData:
     parsed_url = urlparse(video)
-    if parsed_url.scheme in ["http", "https", ""]:
+    if parsed_url.scheme in ["http", "https"]:
+        resp = _safe_request_get(video, stream=False)
+        with tempfile.NamedTemporaryFile(delete=True,
+                                         suffix=".mp4") as tmp_file:
+            tmp_file.write(resp.content)
+            tmp_file.flush()
+            return _load_video_by_cv2(tmp_file.name,
+                                      num_frames,
+                                      fps,
+                                      format,
+                                      device,
+                                      extract_audio=extract_audio)
+    elif parsed_url.scheme in ("", "file"):
         return _load_video_by_cv2(video,
                                   num_frames,
                                   fps,
@@ -439,16 +590,21 @@ def _load_from_bytes(data: bytes) -> VideoData:
                                       extract_audio=extract_audio)
 
     if parsed_url.scheme in ["http", "https"]:
-        session = await _get_aiohttp_session()
-        async with session.get(video) as response:
-            content = await response.content.read()
-            return await asyncio.to_thread(_load_from_bytes, content)
+        video_data = await _safe_aiohttp_get(video)
+        return await asyncio.to_thread(_load_from_bytes, video_data)
     elif parsed_url.scheme == "data":
         decoded_video = load_base64_video(video)
         return await asyncio.to_thread(_load_from_bytes, decoded_video)
+    elif parsed_url.scheme in ("", "file"):
+        return await asyncio.to_thread(_load_video_by_cv2,
+                                       video,
+                                       num_frames,
+                                       fps,
+                                       format,
+                                       device,
+                                       extract_audio=extract_audio)
     else:
-        return await asyncio.to_thread(_load_video_by_cv2, video, num_frames,
-                                       fps, format, device)
+        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
 
 
 def _normalize_file_uri(uri: str) -> str:
@@ -466,10 +622,13 @@ def load_audio(
 ) -> Tuple[np.ndarray, int]:
     parsed_url = urlparse(audio)
     if parsed_url.scheme in ["http", "https"]:
-        audio = requests.get(audio, stream=True, timeout=10)
-        audio = BytesIO(audio.content)
-    elif parsed_url.scheme == "file":
-        audio = _normalize_file_uri(audio)
+        resp = _safe_request_get(audio, stream=False)
+        audio = BytesIO(resp.content)
+    elif parsed_url.scheme in ("", "file"):
+        audio = _normalize_file_uri(
+            audio) if parsed_url.scheme == "file" else audio
+    else:
+        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
 
     audio = soundfile.read(audio)
     return audio
@@ -488,13 +647,13 @@ async def async_load_audio(
     parsed_url = urlparse(audio)
 
     if parsed_url.scheme in ["http", "https"]:
-        session = await _get_aiohttp_session()
-        async with session.get(audio) as response:
-            content = await response.content.read()
-            # Offload CPU-bound soundfile decoding to thread pool
-            return await asyncio.to_thread(soundfile.read, BytesIO(content))
-    elif parsed_url.scheme == "file":
-        audio = _normalize_file_uri(audio)
+        audio_data = await _safe_aiohttp_get(audio)
+        audio = BytesIO(audio_data)
+    elif parsed_url.scheme in ("", "file"):
+        audio = _normalize_file_uri(
+            audio) if parsed_url.scheme == "file" else audio
+    else:
+        raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")
 
     return await asyncio.to_thread(soundfile.read, audio)
 
@@ -502,9 +661,8 @@ async def async_load_audio(
 def encode_base64_content_from_url(content_url: str) -> str:
     """Encode a content retrieved from a remote url to base64 format."""
 
-    with requests.get(content_url, timeout=10) as response:
-        response.raise_for_status()
-        result = base64.b64encode(response.content).decode('utf-8')
+    resp = _safe_request_get(content_url, stream=False)
+    result = base64.b64encode(resp.content).decode('utf-8')
 
     return result
 
diff --git a/tests/unittest/inputs/test_url_validation.py b/tests/unittest/inputs/test_url_validation.py
new file mode 100644
index 000000000000..65c1f7ff07f9
--- /dev/null
+++ b/tests/unittest/inputs/test_url_validation.py
@@ -0,0 +1,333 @@
+"""Unit tests for SSRF-prevention URL validation helpers in inputs/utils.py.
+
+Tests cover _validate_url(), _safe_request_get(), and _safe_aiohttp_get()
+without making real network connections.
+"""
+
+import asyncio
+import socket
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from tensorrt_llm.inputs.utils import (
+    _MAX_REDIRECTS,
+    _MAX_RESPONSE_BYTES,
+    _safe_aiohttp_get,
+    _safe_request_get,
+    _validate_url,
+)
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _dns(ip: str):
+    """Return a minimal getaddrinfo result that resolves to *ip*."""
+    return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0))]
+
+
+PUBLIC_DNS = _dns("93.184.216.34")  # example.com
+
+
+# ---------------------------------------------------------------------------
+# _validate_url
+# ---------------------------------------------------------------------------
+
+
+class TestValidateUrl:
+    def test_rejects_file_scheme(self):
+        with pytest.raises(RuntimeError, match="Only http"):
+            _validate_url("file:///etc/passwd")
+
+    def test_rejects_ftp_scheme(self):
+        with pytest.raises(RuntimeError, match="Only http"):
+            _validate_url("ftp://example.com/file")
+
+    def test_rejects_missing_hostname(self):
+        with pytest.raises(RuntimeError, match="no hostname"):
+            _validate_url("http:///path")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("127.0.0.1"))
+    def test_rejects_loopback_ipv4(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://localhost/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("::1"))
+    def test_rejects_loopback_ipv6(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://ip6-localhost/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("::ffff:127.0.0.1"))
+    def test_rejects_ipv4_mapped_loopback(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://mapped.example/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("10.0.0.1"))
+    def test_rejects_rfc1918_10(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://internal.corp/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("172.16.0.1"))
+    def test_rejects_rfc1918_172(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://vpn.example.com/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("192.168.1.100"))
+    def test_rejects_rfc1918_192(self, _):
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://router.local/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254"))
+    def test_rejects_cloud_metadata_imds(self, _):
+        """AWS / Azure / GCP instance metadata service must be blocked."""
+        with pytest.raises(RuntimeError, match="non-public"):
+            _validate_url("http://169.254.169.254/latest/meta-data/")
+
+    @patch(
+        "tensorrt_llm.inputs.utils.socket.getaddrinfo",
+        side_effect=socket.gaierror("Name or service not known"),
+    )
+    def test_rejects_unresolvable_hostname(self, _):
+        with pytest.raises(RuntimeError, match="Could not resolve"):
+            _validate_url("http://this.does.not.exist.invalid/")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    def test_accepts_public_hostname(self, _):
+        _validate_url("http://example.com/image.jpg")  # must not raise
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    def test_accepts_https(self, _):
+        _validate_url("https://example.com/image.jpg")  # must not raise
+
+
+# ---------------------------------------------------------------------------
+# _safe_request_get
+# ---------------------------------------------------------------------------
+
+
+class TestSafeRequestGet:
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_returns_response_on_success(self, mock_get, _):
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.content = b"fake-image"
+        mock_get.return_value = resp
+        result = _safe_request_get("http://example.com/image.jpg")
+        assert result.status_code == 200
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_follows_valid_redirect(self, mock_get, _):
+        redirect = MagicMock()
+        redirect.status_code = 302
+        redirect.headers = {"Location": "http://example.com/final.jpg"}
+        final = MagicMock()
+        final.status_code = 200
+        final.content = b"image-data"
+        mock_get.side_effect = [redirect, final]
+        result = _safe_request_get("http://example.com/image.jpg")
+        assert result.status_code == 200
+        assert mock_get.call_count == 2
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_follows_relative_redirect(self, mock_get, _):
+        redirect = MagicMock()
+        redirect.status_code = 302
+        redirect.headers = {"Location": "/final.jpg"}
+        final = MagicMock()
+        final.status_code = 200
+        final.content = b"image-data"
+        mock_get.side_effect = [redirect, final]
+        _safe_request_get("http://example.com/dir/image.jpg")
+        assert mock_get.call_args_list[1].args[0] == "http://example.com/final.jpg"
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo")
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_raises_on_redirect_to_private_ip(self, mock_get, mock_dns):
+        def dns_side_effect(host, *args, **kwargs):
+            return _dns("192.168.1.1") if "evil" in str(host) else PUBLIC_DNS
+
+        mock_dns.side_effect = dns_side_effect
+
+        redirect = MagicMock()
+        redirect.status_code = 301
+        redirect.headers = {"Location": "http://evil.internal/secret"}
+        mock_get.return_value = redirect
+
+        with pytest.raises(RuntimeError, match="non-public"):
+            _safe_request_get("http://example.com/image.jpg")
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_raises_after_too_many_redirects(self, mock_get, _):
+        redirect = MagicMock()
+        redirect.status_code = 302
+        redirect.headers = {"Location": "http://example.com/image.jpg"}
+        mock_get.return_value = redirect  # always redirect
+
+        with pytest.raises(RuntimeError, match="Too many redirects"):
+            _safe_request_get("http://example.com/image.jpg")
+
+        # Initial request + _MAX_REDIRECTS follow-ups before giving up
+        assert mock_get.call_count == _MAX_REDIRECTS + 1
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_allows_exactly_max_redirects(self, mock_get, _):
+        redirect = MagicMock()
+        redirect.status_code = 302
+        redirect.headers = {"Location": "http://example.com/image.jpg"}
+        final = MagicMock()
+        final.status_code = 200
+        final.content = b"image-data"
+        mock_get.side_effect = [redirect] * _MAX_REDIRECTS + [final]
+        result = _safe_request_get("http://example.com/image.jpg")
+        assert result is final
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    @patch("tensorrt_llm.inputs.utils.requests.get")
+    def test_raises_on_oversized_response(self, mock_get, _):
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.content = b"x" * (_MAX_RESPONSE_BYTES + 1)
+        mock_get.return_value = resp
+
+        with pytest.raises(RuntimeError, match="maximum allowed size"):
+            _safe_request_get("http://example.com/huge.bin", stream=False)
+
+    def test_rejects_private_url_before_request(self):
+        """_validate_url() must fire before any requests.get() call."""
+        with (
+            patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("10.0.0.1")),
+            patch("tensorrt_llm.inputs.utils.requests.get") as mock_get,
+        ):
+            with pytest.raises(RuntimeError, match="non-public"):
+                _safe_request_get("http://internal.corp/image.jpg")
+        mock_get.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# _safe_aiohttp_get
+# ---------------------------------------------------------------------------
+
+
+class TestSafeAiohttpGet:
+    def _run(self, coro):
+        return asyncio.run(coro)
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    def test_raises_after_too_many_redirects(self, _):
+        class _FakeResponse:
+            status = 302
+            headers = {"Location": "http://example.com/image.jpg"}
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+            def raise_for_status(self):
+                pass
+
+        class _FakeSession:
+            def get(self, url, **kwargs):
+                return _FakeResponse()
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+        with patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession", return_value=_FakeSession()):
+            with pytest.raises(RuntimeError, match="Too many redirects"):
+                self._run(_safe_aiohttp_get("http://example.com/image.jpg"))
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    def test_raises_on_oversized_response(self, _):
+        oversized = b"x" * (_MAX_RESPONSE_BYTES + 1)
+
+        class _FakeResponse:
+            status = 200
+            headers = {}
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+            def raise_for_status(self):
+                pass
+
+            class content:
+                @staticmethod
+                async def read(n):
+                    return oversized[:n]
+
+        class _FakeSession:
+            def get(self, url, **kwargs):
+                return _FakeResponse()
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+        with patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession", return_value=_FakeSession()):
+            with pytest.raises(RuntimeError, match="maximum allowed size"):
+                self._run(_safe_aiohttp_get("http://example.com/huge.bin"))
+
+    @patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS)
+    def test_reads_body_delivered_in_chunks(self, _):
+        chunks = [b"abc", b"def", b""]
+
+        class _FakeResponse:
+            status = 200
+            headers = {}
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+            def raise_for_status(self):
+                pass
+
+            class content:
+                @staticmethod
+                async def read(n):
+                    return chunks.pop(0)
+
+        class _FakeSession:
+            def get(self, url, **kwargs):
+                return _FakeResponse()
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *_):
+                pass
+
+        with patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession", return_value=_FakeSession()):
+            result = self._run(_safe_aiohttp_get("http://example.com/image.jpg"))
+        assert result == b"abcdef"
+
+    def test_rejects_private_url_before_request(self):
+        """_validate_url() must fire before any aiohttp call."""
+        with (
+            patch(
+                "tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254")
+            ),
+            patch("tensorrt_llm.inputs.utils.aiohttp.ClientSession") as mock_session,
+        ):
+            with pytest.raises(RuntimeError, match="non-public"):
+                self._run(_safe_aiohttp_get("http://169.254.169.254/latest/"))
+        mock_session.assert_not_called()