Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 148 additions & 33 deletions tensorrt_llm/inputs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import base64
import ipaddress
import math
import os
import socket
import tempfile
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -111,6 +113,100 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image:
return image.convert(to_mode)


# Maximum allowed response size for remote fetches (200 MB).
_MAX_RESPONSE_BYTES = 200 * 1024 * 1024

# Maximum number of redirects allowed for remote fetches.
_MAX_REDIRECTS = 5


def _validate_url(url: str) -> None:
"""Validate that *url* points to a public, non-internal HTTP(S) resource.

Raises ``RuntimeError`` for URLs that target private, loopback, or
link-local addresses, or that use a scheme other than http / https.

Note: validation is performed at DNS-resolution time. A DNS-rebinding
attack (TTL=0, resolves to a public IP during validation then a private IP
during the actual TCP connect) could bypass this check. For strict
isolation, supplement with network-level egress filtering that blocks
RFC-1918 and APIPA ranges at the host firewall.
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise RuntimeError(
f"Only http and https URLs are allowed, got: {parsed.scheme!r}")

hostname = parsed.hostname
if not hostname:
raise RuntimeError("URL has no hostname")

# Resolve to IP and check address range.
try:
infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
except socket.gaierror as exc:
raise RuntimeError(f"Could not resolve hostname {hostname!r}") from exc

for _family, _type, _proto, _canon, sockaddr in infos:
ip = ipaddress.ip_address(sockaddr[0])
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
raise RuntimeError(f"URL resolves to a non-public address ({ip})")


def _safe_request_get(url: str,
                      *,
                      stream: bool = False,
                      timeout: int = 30) -> "requests.Response":
    """``requests.get`` wrapper that validates the URL and every redirect hop.

    Redirects are followed manually (``allow_redirects=False``) so each hop's
    target can be re-validated before it is fetched. Relative ``Location``
    headers (e.g. ``Location: /other``) are resolved against the current URL
    first — validating them raw would reject them as having no hostname.

    Raises ``RuntimeError`` for unsafe URLs, too many redirects, or (for
    non-streaming requests) responses larger than ``_MAX_RESPONSE_BYTES``.
    Note: when ``stream=True`` the caller consumes ``resp.raw`` directly and
    the size cap is not enforced here.
    """
    from urllib.parse import urljoin

    current_url = url
    _validate_url(current_url)
    resp = requests.get(
        current_url,
        stream=stream,
        timeout=timeout,
        allow_redirects=False,
    )
    for _ in range(_MAX_REDIRECTS):
        if resp.status_code not in (301, 302, 303, 307, 308):
            break
        # Resolve possibly-relative Location against the URL that issued it.
        redirect_url = urljoin(current_url, resp.headers.get("Location", ""))
        _validate_url(redirect_url)
        resp.close()
        current_url = redirect_url
        resp = requests.get(
            current_url,
            stream=stream,
            timeout=timeout,
            allow_redirects=False,
        )
    else:
        raise RuntimeError("Too many redirects")
    resp.raise_for_status()
    if not stream and len(resp.content) > _MAX_RESPONSE_BYTES:
        raise RuntimeError("Response exceeds maximum allowed size")
    return resp


async def _safe_aiohttp_get(url: str, timeout_sec: int = 30) -> bytes:
    """Aiohttp GET wrapper that validates every redirect hop before following.

    ``_validate_url`` performs blocking DNS resolution via
    ``socket.getaddrinfo``, so it is offloaded with ``asyncio.to_thread`` to
    keep the event loop responsive on slow DNS. Relative ``Location`` headers
    are resolved against the current URL before validation.

    Raises ``RuntimeError`` for unsafe URLs, too many redirects, or responses
    larger than ``_MAX_RESPONSE_BYTES``.
    """
    from urllib.parse import urljoin

    await asyncio.to_thread(_validate_url, url)
    timeout = aiohttp.ClientTimeout(total=timeout_sec)
    current_url = url
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for _ in range(_MAX_REDIRECTS + 1):
            async with session.get(current_url,
                                   allow_redirects=False) as response:
                if response.status in (301, 302, 303, 307, 308):
                    # Resolve possibly-relative Location before validating.
                    redirect_url = urljoin(
                        current_url, response.headers.get("Location", ""))
                    await asyncio.to_thread(_validate_url, redirect_url)
                    current_url = redirect_url
                    continue
                response.raise_for_status()
                # Read one byte past the cap so oversize bodies are detected
                # without buffering arbitrarily more than the limit.
                data = await response.content.read(_MAX_RESPONSE_BYTES + 1)
                if len(data) > _MAX_RESPONSE_BYTES:
                    raise RuntimeError("Response exceeds maximum allowed size")
                return data
    raise RuntimeError("Too many redirects")


def _load_and_convert_image(image):
image = Image.open(image)
image.load()
Expand Down Expand Up @@ -150,12 +246,14 @@ def load_image(image: Union[str, Image.Image],
parsed_url = urlparse(image)

if parsed_url.scheme in ["http", "https"]:
image = requests.get(image, stream=True, timeout=10).raw
image = _load_and_convert_image(image)
resp = _safe_request_get(image, stream=True)
image = _load_and_convert_image(resp.raw)
elif parsed_url.scheme == "data":
image = load_base64_image(parsed_url)
else:
elif parsed_url.scheme in ("", "file"):
image = _load_and_convert_image(image)
else:
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")

if format == "pt":
return ToTensor()(image).to(device=device)
Expand All @@ -175,16 +273,14 @@ async def async_load_image(
parsed_url = urlparse(image)

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(image) as response:
content = await response.read()
image = await asyncio.to_thread(_load_and_convert_image,
BytesIO(content))
content = await _safe_aiohttp_get(image)
image = _load_and_convert_image(BytesIO(content))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we are not using async, can you go back to original code where we used asyncio.to_thread?

elif parsed_url.scheme == "data":
image = await asyncio.to_thread(load_base64_image, parsed_url)
image = load_base64_image(parsed_url)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sams as above.

elif parsed_url.scheme in ("", "file"):
image = _load_and_convert_image(Path(parsed_url.path))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

else:
image = await asyncio.to_thread(_load_and_convert_image,
Path(parsed_url.path))
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")

if format == "pt":
return await asyncio.to_thread(lambda: ToTensor()
Expand Down Expand Up @@ -394,7 +490,19 @@ def load_video(video: str,
device: str = "cpu",
extract_audio: bool = False) -> VideoData:
parsed_url = urlparse(video)
if parsed_url.scheme in ["http", "https", ""]:
if parsed_url.scheme in ["http", "https"]:
resp = _safe_request_get(video, stream=False)
with tempfile.NamedTemporaryFile(delete=True,
suffix=".mp4") as tmp_file:
tmp_file.write(resp.content)
tmp_file.flush()
return _load_video_by_cv2(tmp_file.name,
num_frames,
fps,
format,
device,
extract_audio=extract_audio)
elif parsed_url.scheme in ("", "file"):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return _load_video_by_cv2(video,
num_frames,
fps,
Expand Down Expand Up @@ -439,16 +547,21 @@ def _load_from_bytes(data: bytes) -> VideoData:
extract_audio=extract_audio)

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(video) as response:
content = await response.content.read()
return await asyncio.to_thread(_load_from_bytes, content)
video_data = await _safe_aiohttp_get(video)
return _load_from_bytes(video_data)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as image case.

elif parsed_url.scheme == "data":
decoded_video = load_base64_video(video)
return await asyncio.to_thread(_load_from_bytes, decoded_video)
return _load_from_bytes(decoded_video)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

elif parsed_url.scheme in ("", "file"):
return await asyncio.to_thread(_load_video_by_cv2,
video,
num_frames,
fps,
format,
device,
extract_audio=extract_audio)
else:
return await asyncio.to_thread(_load_video_by_cv2, video, num_frames,
fps, format, device)
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")


def _normalize_file_uri(uri: str) -> str:
Expand All @@ -466,10 +579,13 @@ def load_audio(
) -> Tuple[np.ndarray, int]:
parsed_url = urlparse(audio)
if parsed_url.scheme in ["http", "https"]:
audio = requests.get(audio, stream=True, timeout=10)
audio = BytesIO(audio.content)
elif parsed_url.scheme == "file":
audio = _normalize_file_uri(audio)
resp = _safe_request_get(audio, stream=False)
audio = BytesIO(resp.content)
elif parsed_url.scheme in ("", "file"):
audio = _normalize_file_uri(
audio) if parsed_url.scheme == "file" else audio
else:
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")

audio = soundfile.read(audio)
return audio
Expand All @@ -488,23 +604,22 @@ async def async_load_audio(
parsed_url = urlparse(audio)

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(audio) as response:
content = await response.content.read()
# Offload CPU-bound soundfile decoding to thread pool
return await asyncio.to_thread(soundfile.read, BytesIO(content))
elif parsed_url.scheme == "file":
audio = _normalize_file_uri(audio)
audio_data = await _safe_aiohttp_get(audio)
audio = BytesIO(audio_data)
elif parsed_url.scheme in ("", "file"):
audio = _normalize_file_uri(
audio) if parsed_url.scheme == "file" else audio
else:
raise ValueError(f"Unsupported URL scheme: {parsed_url.scheme!r}")

return await asyncio.to_thread(soundfile.read, audio)


def encode_base64_content_from_url(content_url: str) -> str:
    """Fetch *content_url* and return its body encoded as a base64 string.

    The fetch goes through :func:`_safe_request_get`, so the URL (and every
    redirect hop) is SSRF-validated and the response size is capped.

    Raises ``RuntimeError`` for unsafe URLs or oversized responses, and
    ``requests.HTTPError`` for non-2xx responses.
    """
    resp = _safe_request_get(content_url, stream=False)
    return base64.b64encode(resp.content).decode('utf-8')

Expand Down
Loading
Loading