Skip to content

Commit 0ad287c

Browse files
authored
[ONNX][Autotune] Replace CUDA memory management from CUDART to PyTorch (#998)
### What does this PR do? **Type of change**: Bug fix **Overview**: Replace CUDA memory management from CUDART to PyTorch (higher-level API). ### Usage ```python # Add a code snippet demonstrating how to use this ``` ### Testing 1. Added unittests. 2. Tested that this PR does not break #951 or #978 ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, using `torch.load(..., weights_only=True)`, avoiding `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other source, did you follow IP policy in [CONTRIBUTING.md](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-copying-code-from-other-sources)?: N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> ### Additional Information Summary of changes in `benchmark.py — TensorRTPyBenchmark`: | What changed | Before | After | |---|---|---| | Imports | `contextlib` + `from cuda.bindings import runtime as cudart` | `import torch` (conditional) | | Availability flag | `CUDART_AVAILABLE` | `TORCH_CUDA_AVAILABLE = torch.cuda.is_available()` | | `__init__` guard | checks `CUDART_AVAILABLE or cudart is None` | checks `TORCH_CUDA_AVAILABLE` | | `_alloc_pinned_host` | `cudaMallocHost` + ctypes address hack, returns `(ptr, arr, err)` | `torch.empty(...).pin_memory()`, returns `(tensor, tensor.numpy())` | | `_free_buffers` | `cudaFreeHost` + `cudaFree` per buffer | `bufs.clear()` — PyTorch GC handles deallocation | | `_allocate_buffers` | raw `device_ptr` integers, error-code returns | `torch.empty(..., device="cuda")`, `tensor.data_ptr()` for TRT address | | `_run_warmup` | `cudaMemcpyAsync` + `cudaStreamSynchronize` | `tensor.copy_(non_blocking=True)` inside `torch.cuda.stream()` | | `_run_timing` | same cudart pattern | same torch pattern | | `run` — stream lifecycle | `cudaStreamCreate()` / `cudaStreamDestroy()` | `torch.cuda.Stream()` / `del stream` | | `run` — stream arg to TRT | raw integer handle | `stream.cuda_stream` (integer property) | | Error handling | `cudaError_t` return codes | PyTorch raises `RuntimeError`, caught by existing `except Exception` | Related to #961 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * TensorRT benchmarking migrated from direct CUDA runtime calls to PyTorch CUDA tensors, pinned memory, and CUDA stream primitives — simplifying buffer management, transfers, and timing semantics. * **Tests** * Expanded GPU autotune benchmark tests with broader unit and integration coverage for CUDA/TensorRT paths, pinned-host/device buffering, stream behavior, warmup/timing, and end-to-end latency scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent d3748c2 commit 0ad287c

2 files changed

Lines changed: 366 additions & 160 deletions

File tree

modelopt/onnx/quantization/autotune/benchmark.py

Lines changed: 56 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
- TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling
2727
"""
2828

29-
import contextlib
3029
import ctypes
3130
import importlib.util
3231
import os
@@ -40,6 +39,7 @@
4039
from typing import Any
4140

4241
import numpy as np
42+
import torch
4343

4444
from modelopt.onnx.logging_config import logger
4545
from modelopt.onnx.quantization.ort_utils import _check_for_tensorrt
@@ -48,13 +48,7 @@
4848
if TRT_AVAILABLE:
4949
import tensorrt as trt
5050

51-
CUDART_AVAILABLE = importlib.util.find_spec("cuda") is not None
52-
if CUDART_AVAILABLE:
53-
try:
54-
from cuda.bindings import runtime as cudart
55-
except ImportError:
56-
with contextlib.suppress(ImportError):
57-
from cuda import cudart # deprecated: prefer cuda.bindings.runtime
51+
TORCH_CUDA_AVAILABLE = torch.cuda.is_available()
5852

5953

6054
def _validate_shape_range(min_shape: list, opt_shape: list, max_shape: list) -> None:
@@ -337,17 +331,17 @@ def __init__(
337331
engine building. If None, no custom plugins are loaded.
338332
339333
Raises:
340-
ImportError: If tensorrt or cuda-python (cudart) packages are not available.
334+
ImportError: If tensorrt is not installed or if torch is not built with CUDA support.
341335
FileNotFoundError: If a specified plugin library file does not exist.
342336
RuntimeError: If plugin library loading fails.
343337
"""
344338
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)
345339

346340
if not TRT_AVAILABLE:
347341
raise ImportError("TensorRT Python API not available. Please install tensorrt package.")
348-
if not CUDART_AVAILABLE or cudart is None:
342+
if not TORCH_CUDA_AVAILABLE:
349343
raise ImportError(
350-
"CUDA Runtime (cudart) not available. Please install cuda-python package: pip install cuda-python"
344+
"PyTorch with CUDA support not available. Please install torch with CUDA: pip install torch"
351345
)
352346

353347
self.trt_logger = trt.Logger(trt.Logger.WARNING)
@@ -527,107 +521,57 @@ def _build_engine(
527521
del parser, network, config
528522

529523
@staticmethod
530-
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray, Any]:
531-
"""Allocate pinned host memory and return (host_ptr, array view, cuda error).
524+
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray]:
525+
"""Allocate pinned host memory using PyTorch and return (tensor, numpy_view).
532526
533527
Returns:
534-
(host_ptr, arr, err): On success err is cudaSuccess; on failure host_ptr/arr
535-
may be None and err is the CUDA error code.
528+
(host_tensor, arr): Pinned PyTorch tensor and a numpy view over it.
536529
"""
537-
dtype = np.dtype(dtype)
538-
nbytes = size * dtype.itemsize
539-
err, host_ptr = cudart.cudaMallocHost(nbytes)
540-
if err != cudart.cudaError_t.cudaSuccess:
541-
return (None, None, err)
542-
addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr
543-
try:
544-
ctype = np.ctypeslib.as_ctypes_type(dtype)
545-
arr = np.ctypeslib.as_array((ctype * size).from_address(addr))
546-
except NotImplementedError as e:
547-
# float16/bfloat16 have no ctypes equivalent; use same-size type and view
548-
if dtype.itemsize == 2:
549-
ctype = ctypes.c_uint16
550-
else:
551-
raise TypeError(
552-
f"Pinned host allocation for dtype {dtype} is not supported: "
553-
"no ctypes mapping and no fallback for this itemsize"
554-
) from e
555-
arr = np.ctypeslib.as_array((ctype * size).from_address(addr)).view(dtype)
556-
return (host_ptr, arr, cudart.cudaError_t.cudaSuccess)
530+
torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
531+
host_tensor = torch.empty(int(size), dtype=torch_dtype).pin_memory()
532+
return host_tensor, host_tensor.numpy()
557533

558534
@staticmethod
559535
def _free_buffers(bufs: list[dict]) -> None:
560-
"""Free host and device memory for a list of buffer dicts (host_ptr, device_ptr)."""
561-
for buf in bufs:
562-
if "host_ptr" in buf and buf["host_ptr"] is not None:
563-
cudart.cudaFreeHost(buf["host_ptr"])
564-
if "device_ptr" in buf and buf["device_ptr"] is not None:
565-
cudart.cudaFree(buf["device_ptr"])
536+
"""Release buffer references; PyTorch handles underlying memory deallocation."""
537+
bufs.clear()
566538

567539
def _allocate_buffers(
568540
self,
569541
engine: "trt.ICudaEngine",
570542
context: "trt.IExecutionContext",
571-
) -> tuple[list[dict], list[dict], Any]:
572-
"""Allocate host and device buffers for engine I/O and set tensor addresses.
543+
) -> tuple[list[dict], list[dict]]:
544+
"""Allocate pinned host and device tensors for engine I/O and set tensor addresses.
573545
574546
Args:
575547
engine: Deserialized TensorRT engine.
576548
context: Execution context with tensor shapes set.
577549
578550
Returns:
579-
(inputs, outputs, cuda_error): On success cuda_error is cudaSuccess;
580-
on failure inputs/outputs are empty and cuda_error is the failing CUDA error code.
551+
(inputs, outputs): Lists of buffer dicts containing PyTorch tensors.
581552
"""
582553
inputs: list[dict] = []
583554
outputs: list[dict] = []
584555

585556
for i in range(engine.num_io_tensors):
586557
tensor_name = engine.get_tensor_name(i)
587-
dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
558+
np_dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
588559
shape = context.get_tensor_shape(tensor_name)
560+
size = int(trt.volume(shape))
589561

590-
size = trt.volume(shape)
591-
nbytes = size * np.dtype(dtype).itemsize
562+
host_tensor, host_mem = self._alloc_pinned_host(size, np_dtype)
563+
torch_dtype = torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype
564+
device_tensor = torch.empty(size, dtype=torch_dtype, device="cuda")
592565

593-
err, device_ptr = cudart.cudaMalloc(nbytes)
594-
if err != cudart.cudaError_t.cudaSuccess:
595-
self.logger.error(f"cudaMalloc failed: {err}")
596-
self._free_buffers(inputs + outputs)
597-
return ([], [], err)
598-
599-
host_ptr, host_mem, err = self._alloc_pinned_host(size, dtype)
600-
if err != cudart.cudaError_t.cudaSuccess:
601-
self.logger.error(f"cudaMallocHost failed: {err}")
602-
cudart.cudaFree(device_ptr)
603-
self._free_buffers(inputs + outputs)
604-
return ([], [], err)
566+
context.set_tensor_address(tensor_name, device_tensor.data_ptr())
605567

606568
if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
607-
np.copyto(host_mem, np.random.randn(size).astype(dtype))
608-
inputs.append(
609-
{
610-
"host_ptr": host_ptr,
611-
"host": host_mem,
612-
"device_ptr": device_ptr,
613-
"nbytes": nbytes,
614-
"name": tensor_name,
615-
}
616-
)
569+
np.copyto(host_mem, np.random.randn(size).astype(np_dtype))
570+
inputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})
617571
else:
618-
outputs.append(
619-
{
620-
"host_ptr": host_ptr,
621-
"host": host_mem,
622-
"device_ptr": device_ptr,
623-
"nbytes": nbytes,
624-
"name": tensor_name,
625-
}
626-
)
627-
628-
context.set_tensor_address(tensor_name, int(device_ptr))
572+
outputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})
629573

630-
return (inputs, outputs, cudart.cudaError_t.cudaSuccess)
574+
return (inputs, outputs)
631575

632576
def _setup_execution_context(
633577
self, serialized_engine: bytes
@@ -652,55 +596,44 @@ def _run_warmup(
652596
context: "trt.IExecutionContext",
653597
inputs: list[dict],
654598
outputs: list[dict],
655-
stream_handle: Any,
599+
stream: "torch.cuda.Stream",
656600
) -> None:
657601
"""Run warmup iterations to stabilize GPU state and cache."""
658-
h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
659-
d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
660602
self.logger.debug(f"Running {self.warmup_runs} warmup iterations...")
661-
for _ in range(self.warmup_runs):
662-
for inp in inputs:
663-
cudart.cudaMemcpyAsync(
664-
inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle
665-
)
666-
context.execute_async_v3(stream_handle)
667-
for out in outputs:
668-
cudart.cudaMemcpyAsync(
669-
out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle
670-
)
671-
cudart.cudaStreamSynchronize(stream_handle)
603+
with torch.cuda.stream(stream):
604+
for _ in range(self.warmup_runs):
605+
for inp in inputs:
606+
inp["device"].copy_(inp["host"], non_blocking=True)
607+
context.execute_async_v3(stream.cuda_stream)
608+
for out in outputs:
609+
out["host"].copy_(out["device"], non_blocking=True)
610+
stream.synchronize()
672611

673612
def _run_timing(
674613
self,
675614
context: "trt.IExecutionContext",
676615
inputs: list[dict],
677616
outputs: list[dict],
678-
stream_handle: Any,
617+
stream: "torch.cuda.Stream",
679618
) -> np.ndarray:
680619
"""Run timing iterations and return per-run latencies in milliseconds."""
681-
h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
682-
d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
683620
self.logger.debug(f"Running {self.timing_runs} timing iterations...")
684621
latencies = []
685-
for _ in range(self.timing_runs):
686-
for inp in inputs:
687-
cudart.cudaMemcpyAsync(
688-
inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle
689-
)
622+
with torch.cuda.stream(stream):
623+
for _ in range(self.timing_runs):
624+
for inp in inputs:
625+
inp["device"].copy_(inp["host"], non_blocking=True)
690626

691-
cudart.cudaStreamSynchronize(stream_handle)
692-
start = time.perf_counter()
693-
context.execute_async_v3(stream_handle)
694-
cudart.cudaStreamSynchronize(stream_handle)
695-
end = time.perf_counter()
627+
stream.synchronize()
628+
start = time.perf_counter()
629+
context.execute_async_v3(stream.cuda_stream)
630+
stream.synchronize()
631+
end = time.perf_counter()
696632

697-
latency_ms = (end - start) * 1000.0
698-
latencies.append(latency_ms)
633+
latencies.append((end - start) * 1000.0)
699634

700-
for out in outputs:
701-
cudart.cudaMemcpyAsync(
702-
out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle
703-
)
635+
for out in outputs:
636+
out["host"].copy_(out["device"], non_blocking=True)
704637

705638
return np.array(latencies)
706639

@@ -721,7 +654,7 @@ def run(
721654
Measured median latency in milliseconds, or float("inf") on any error
722655
(e.g. build failure, deserialization failure, buffer/stream allocation failure).
723656
"""
724-
serialized_engine = engine = context = stream_handle = None
657+
serialized_engine = engine = context = stream = None
725658
inputs, outputs = [], []
726659

727660
try:
@@ -733,19 +666,11 @@ def run(
733666
if engine is None or context is None:
734667
return float("inf")
735668

736-
inputs, outputs, alloc_err = self._allocate_buffers(engine, context)
737-
if alloc_err != cudart.cudaError_t.cudaSuccess:
738-
self.logger.error(f"Buffer allocation failed: {alloc_err}")
739-
return float("inf")
669+
inputs, outputs = self._allocate_buffers(engine, context)
670+
stream = torch.cuda.Stream()
740671

741-
err, sh = cudart.cudaStreamCreate()
742-
if err != cudart.cudaError_t.cudaSuccess:
743-
self.logger.error(f"cudaStreamCreate failed: {err}")
744-
return float("inf")
745-
stream_handle = sh
746-
747-
self._run_warmup(context, inputs, outputs, stream_handle)
748-
latencies = self._run_timing(context, inputs, outputs, stream_handle)
672+
self._run_warmup(context, inputs, outputs, stream)
673+
latencies = self._run_timing(context, inputs, outputs, stream)
749674

750675
median_latency = float(np.median(latencies))
751676
mean_latency = float(np.mean(latencies))
@@ -788,17 +713,9 @@ def run(
788713
return float("inf")
789714
finally:
790715
try:
791-
self._free_buffers(inputs + outputs)
792-
if stream_handle is not None:
793-
cudart.cudaStreamDestroy(stream_handle)
794-
del (
795-
inputs,
796-
outputs,
797-
stream_handle,
798-
context,
799-
engine,
800-
serialized_engine,
801-
)
716+
self._free_buffers(inputs)
717+
self._free_buffers(outputs)
718+
del inputs, outputs, stream, context, engine, serialized_engine
802719
except Exception as cleanup_error:
803720
self.logger.warning(f"Error during cleanup: {cleanup_error}")
804721

0 commit comments

Comments
 (0)