-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[https://nvbugs/5911304][fix] Add URL validation and request hardening for media input loading #12748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[https://nvbugs/5911304][fix] Add URL validation and request hardening for media input loading #12748
Changes from all commits
92cfc16
8c3a795
0791e2c
821f56c
7c1c17f
21b0fd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| import asyncio | ||
| import base64 | ||
| import ipaddress | ||
| import math | ||
| import os | ||
| import socket | ||
| import tempfile | ||
| from collections import defaultdict | ||
| from dataclasses import dataclass | ||
|
|
@@ -111,6 +113,100 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: | |
| return image.convert(to_mode) | ||
|
|
||
|
|
||
| # Maximum allowed response size for remote fetches (200 MB). | ||
| _MAX_RESPONSE_BYTES = 200 * 1024 * 1024 | ||
|
|
||
| # Maximum number of redirects allowed for remote fetches. | ||
| _MAX_REDIRECTS = 5 | ||
|
|
||
|
|
||
| def _validate_url(url: str) -> None: | ||
| """Validate that *url* points to a public, non-internal HTTP(S) resource. | ||
|
|
||
| Raises ``RuntimeError`` for URLs that target private, loopback, or | ||
| link-local addresses, or that use a scheme other than http / https. | ||
|
|
||
| Note: validation is performed at DNS-resolution time. A DNS-rebinding | ||
| attack (TTL=0, resolves to a public IP during validation then a private IP | ||
| during the actual TCP connect) could bypass this check. For strict | ||
| isolation, supplement with network-level egress filtering that blocks | ||
| RFC-1918 and APIPA ranges at the host firewall. | ||
| """ | ||
| parsed = urlparse(url) | ||
| if parsed.scheme not in ("http", "https"): | ||
| raise RuntimeError( | ||
| f"Only http and https URLs are allowed, got: {parsed.scheme!r}") | ||
|
|
||
| hostname = parsed.hostname | ||
| if not hostname: | ||
| raise RuntimeError("URL has no hostname") | ||
|
|
||
| # Resolve to IP and check address range. | ||
| try: | ||
| infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP) | ||
| except socket.gaierror as exc: | ||
| raise RuntimeError(f"Could not resolve hostname {hostname!r}") from exc | ||
|
|
||
| for _family, _type, _proto, _canon, sockaddr in infos: | ||
| ip = ipaddress.ip_address(sockaddr[0]) | ||
| if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: | ||
| raise RuntimeError(f"URL resolves to a non-public address ({ip})") | ||
|
|
||
|
|
||
def _safe_request_get(url: str,
                      *,
                      stream: bool = False,
                      timeout: int = 30) -> "requests.Response":
    """``requests.get`` wrapper that validates the URL before fetching.

    Redirects are followed manually (``allow_redirects=False``) so that every
    hop is re-validated by ``_validate_url`` — otherwise a public URL could
    redirect to an internal address and bypass the SSRF check.

    Args:
        url: The http(s) URL to fetch.
        stream: Passed through to ``requests.get``. NOTE(review): when
            ``stream=True`` the size cap below is not enforced here; callers
            consume ``resp.raw`` directly.
        timeout: Per-request timeout in seconds (applies to each hop).

    Returns:
        The final, non-redirect ``requests.Response``.

    Raises:
        RuntimeError: If any hop fails validation, the chain exceeds
            ``_MAX_REDIRECTS`` redirects, or a non-streamed body exceeds
            ``_MAX_RESPONSE_BYTES``.
        requests.HTTPError: If the final response has an error status.
    """
    # Local import: resolve relative Location headers (RFC 7231 section
    # 7.1.2 allows them) against the URL of the hop that issued the redirect.
    from urllib.parse import urljoin

    _validate_url(url)
    current_url = url
    resp = requests.get(current_url,
                        stream=stream,
                        timeout=timeout,
                        allow_redirects=False)
    # _MAX_REDIRECTS + 1 iterations: one status check per hop, so a chain of
    # exactly _MAX_REDIRECTS redirects ending in a final response succeeds.
    for _ in range(_MAX_REDIRECTS + 1):
        if resp.status_code not in (301, 302, 303, 307, 308):
            break
        redirect_url = urljoin(current_url, resp.headers.get("Location", ""))
        _validate_url(redirect_url)
        resp.close()
        current_url = redirect_url
        resp = requests.get(current_url,
                            stream=stream,
                            timeout=timeout,
                            allow_redirects=False)
    else:
        # Loop exhausted without reaching a non-redirect response.
        resp.close()
        raise RuntimeError("Too many redirects")
    try:
        resp.raise_for_status()
        if not stream and len(resp.content) > _MAX_RESPONSE_BYTES:
            raise RuntimeError("Response exceeds maximum allowed size")
    except Exception:
        # Don't leak the connection when the response is rejected.
        resp.close()
        raise
    return resp
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
|
|
||
async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
    """Aiohttp GET wrapper that validates every redirect hop before following.

    Args:
        url: The http(s) URL to fetch.
        timeout_sec: Total timeout, in seconds, for the whole session.

    Returns:
        The response body as ``bytes``.

    Raises:
        RuntimeError: If any hop fails ``_validate_url``, the chain exceeds
            ``_MAX_REDIRECTS`` redirects, or the body exceeds
            ``_MAX_RESPONSE_BYTES``.
        aiohttp.ClientResponseError: If the final response has an error status.
    """
    # Local import: Location headers may be relative (RFC 7231 section 7.1.2);
    # they must be made absolute before validation or _validate_url rejects
    # them for having no scheme.
    from urllib.parse import urljoin

    _validate_url(url)
    timeout = aiohttp.ClientTimeout(total=timeout_sec)
    current_url = url
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # _MAX_REDIRECTS + 1 fetches: the initial request plus one per
        # allowed redirect hop.
        for _ in range(_MAX_REDIRECTS + 1):
            async with session.get(current_url,
                                   allow_redirects=False) as response:
                if response.status in (301, 302, 303, 307, 308):
                    redirect_url = urljoin(
                        current_url, response.headers.get("Location", ""))
                    _validate_url(redirect_url)
                    current_url = redirect_url
                    continue
                response.raise_for_status()
                # Read one byte past the cap so oversized bodies are
                # detectable without buffering the entire stream.
                data = await response.content.read(_MAX_RESPONSE_BYTES + 1)
                if len(data) > _MAX_RESPONSE_BYTES:
                    raise RuntimeError(
                        "Response exceeds maximum allowed size")
                return data
        raise RuntimeError("Too many redirects")
|
|
||
|
|
||
| def _load_and_convert_image(image): | ||
| image = Image.open(image) | ||
| image.load() | ||
|
|
@@ -150,12 +246,14 @@ def load_image(image: Union[str, Image.Image], | |
| parsed_url = urlparse(image) | ||
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| image = requests.get(image, stream=True, timeout=10).raw | ||
| image = _load_and_convert_image(image) | ||
| resp = _safe_request_get(image, stream=True) | ||
| image = _load_and_convert_image(resp.raw) | ||
| elif parsed_url.scheme == "data": | ||
| image = load_base64_image(parsed_url) | ||
| else: | ||
| elif parsed_url.scheme in ("", "file"): | ||
| image = _load_and_convert_image(image) | ||
| else: | ||
| raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") | ||
|
|
||
| if format == "pt": | ||
| return ToTensor()(image).to(device=device) | ||
|
|
@@ -175,16 +273,14 @@ async def async_load_image( | |
| parsed_url = urlparse(image) | ||
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(image) as response: | ||
| content = await response.read() | ||
| image = await asyncio.to_thread(_load_and_convert_image, | ||
| BytesIO(content)) | ||
| content = await _safe_aiohttp_get(image) | ||
| image = _load_and_convert_image(BytesIO(content)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems we are not using |
||
| elif parsed_url.scheme == "data": | ||
| image = await asyncio.to_thread(load_base64_image, parsed_url) | ||
| image = load_base64_image(parsed_url) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sams as above. |
||
| elif parsed_url.scheme in ("", "file"): | ||
| image = _load_and_convert_image(Path(parsed_url.path)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
| else: | ||
| image = await asyncio.to_thread(_load_and_convert_image, | ||
| Path(parsed_url.path)) | ||
| raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") | ||
|
|
||
| if format == "pt": | ||
| return await asyncio.to_thread(lambda: ToTensor() | ||
|
|
@@ -394,7 +490,19 @@ def load_video(video: str, | |
| device: str = "cpu", | ||
| extract_audio: bool = False) -> VideoData: | ||
| parsed_url = urlparse(video) | ||
| if parsed_url.scheme in ["http", "https", ""]: | ||
| if parsed_url.scheme in ["http", "https"]: | ||
| resp = _safe_request_get(video, stream=False) | ||
| with tempfile.NamedTemporaryFile(delete=True, | ||
| suffix=".mp4") as tmp_file: | ||
| tmp_file.write(resp.content) | ||
| tmp_file.flush() | ||
| return _load_video_by_cv2(tmp_file.name, | ||
| num_frames, | ||
| fps, | ||
| format, | ||
| device, | ||
| extract_audio=extract_audio) | ||
| elif parsed_url.scheme in ("", "file"): | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| return _load_video_by_cv2(video, | ||
| num_frames, | ||
| fps, | ||
|
|
@@ -439,16 +547,21 @@ def _load_from_bytes(data: bytes) -> VideoData: | |
| extract_audio=extract_audio) | ||
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(video) as response: | ||
| content = await response.content.read() | ||
| return await asyncio.to_thread(_load_from_bytes, content) | ||
| video_data = await _safe_aiohttp_get(video) | ||
| return _load_from_bytes(video_data) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as image case. |
||
| elif parsed_url.scheme == "data": | ||
| decoded_video = load_base64_video(video) | ||
| return await asyncio.to_thread(_load_from_bytes, decoded_video) | ||
| return _load_from_bytes(decoded_video) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
| elif parsed_url.scheme in ("", "file"): | ||
| return await asyncio.to_thread(_load_video_by_cv2, | ||
| video, | ||
| num_frames, | ||
| fps, | ||
| format, | ||
| device, | ||
| extract_audio=extract_audio) | ||
| else: | ||
| return await asyncio.to_thread(_load_video_by_cv2, video, num_frames, | ||
| fps, format, device) | ||
| raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") | ||
|
|
||
|
|
||
| def _normalize_file_uri(uri: str) -> str: | ||
|
|
@@ -466,10 +579,13 @@ def load_audio( | |
| ) -> Tuple[np.ndarray, int]: | ||
| parsed_url = urlparse(audio) | ||
| if parsed_url.scheme in ["http", "https"]: | ||
| audio = requests.get(audio, stream=True, timeout=10) | ||
| audio = BytesIO(audio.content) | ||
| elif parsed_url.scheme == "file": | ||
| audio = _normalize_file_uri(audio) | ||
| resp = _safe_request_get(audio, stream=False) | ||
| audio = BytesIO(resp.content) | ||
| elif parsed_url.scheme in ("", "file"): | ||
| audio = _normalize_file_uri( | ||
| audio) if parsed_url.scheme == "file" else audio | ||
| else: | ||
| raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") | ||
|
|
||
| audio = soundfile.read(audio) | ||
| return audio | ||
|
|
@@ -488,23 +604,22 @@ async def async_load_audio( | |
| parsed_url = urlparse(audio) | ||
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(audio) as response: | ||
| content = await response.content.read() | ||
| # Offload CPU-bound soundfile decoding to thread pool | ||
| return await asyncio.to_thread(soundfile.read, BytesIO(content)) | ||
| elif parsed_url.scheme == "file": | ||
| audio = _normalize_file_uri(audio) | ||
| audio_data = await _safe_aiohttp_get(audio) | ||
| audio = BytesIO(audio_data) | ||
| elif parsed_url.scheme in ("", "file"): | ||
| audio = _normalize_file_uri( | ||
| audio) if parsed_url.scheme == "file" else audio | ||
| else: | ||
| raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}") | ||
|
|
||
| return await asyncio.to_thread(soundfile.read, audio) | ||
|
|
||
|
|
||
| def encode_base64_content_from_url(content_url: str) -> str: | ||
| """Encode a content retrieved from a remote url to base64 format.""" | ||
|
|
||
| with requests.get(content_url, timeout=10) as response: | ||
| response.raise_for_status() | ||
| result = base64.b64encode(response.content).decode('utf-8') | ||
| resp = _safe_request_get(content_url, stream=False) | ||
| result = base64.b64encode(resp.content).decode('utf-8') | ||
|
|
||
| return result | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.