Skip to content

Commit 57e03ab

Browse files
committed
align TensorRTPyBenchmark error handling
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent d74bb08 commit 57e03ab

1 file changed

Lines changed: 55 additions & 33 deletions

File tree

modelopt/onnx/quantization/autotune/benchmark.py

Lines changed: 55 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ def __call__(self, path_or_bytes: str | bytes, log_file: str | None = None) -> f
134134
log_file: Optional path to save benchmark logs
135135
136136
Returns:
137-
Measured latency in milliseconds
137+
Measured latency in milliseconds, or float("inf") on failure.
138138
"""
139139
return self.run(path_or_bytes, log_file)
140140

@@ -526,6 +526,32 @@ def _build_engine(
526526
finally:
527527
del parser, network, config
528528

529+
@staticmethod
530+
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray, Any]:
531+
"""Allocate pinned host memory and return (host_ptr, array view, cuda error).
532+
533+
Returns:
534+
(host_ptr, arr, err): On success err is cudaSuccess; on failure host_ptr/arr
535+
may be None and err is the CUDA error code.
536+
"""
537+
nbytes = size * np.dtype(dtype).itemsize
538+
err, host_ptr = cudart.cudaMallocHost(nbytes)
539+
if err != cudart.cudaError_t.cudaSuccess:
540+
return (None, None, err)
541+
addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr
542+
ctype = np.ctypeslib.as_ctypes_type(dtype)
543+
arr = np.ctypeslib.as_array((ctype * size).from_address(addr))
544+
return (host_ptr, arr, cudart.cudaError_t.cudaSuccess)
545+
546+
@staticmethod
547+
def _free_buffers(bufs: list[dict]) -> None:
548+
"""Free host and device memory for a list of buffer dicts (host_ptr, device_ptr)."""
549+
for buf in bufs:
550+
if "host_ptr" in buf and buf["host_ptr"] is not None:
551+
cudart.cudaFreeHost(buf["host_ptr"])
552+
if "device_ptr" in buf and buf["device_ptr"] is not None:
553+
cudart.cudaFree(buf["device_ptr"])
554+
529555
def _allocate_buffers(
530556
self,
531557
engine: "trt.ICudaEngine",
@@ -538,25 +564,11 @@ def _allocate_buffers(
538564
context: Execution context with tensor shapes set.
539565
540566
Returns:
541-
(inputs, outputs, stream_handle) where inputs/outputs are lists of buffer dicts
542-
with keys host_ptr, host, device_ptr, nbytes, name; stream_handle is a CUDA stream.
543-
544-
Raises:
545-
RuntimeError: If CUDA allocation or stream creation fails.
567+
(inputs, outputs, cuda_error): On success cuda_error is cudaSuccess;
568+
on failure inputs/outputs are empty and cuda_error is the failing CUDA error code.
546569
"""
547-
548-
def _alloc_pinned_host(size: int, dtype: np.dtype):
549-
nbytes = size * np.dtype(dtype).itemsize
550-
err, host_ptr = cudart.cudaMallocHost(nbytes)
551-
if err != cudart.cudaError_t.cudaSuccess:
552-
raise RuntimeError(f"cudaMallocHost failed: {err}")
553-
addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr
554-
ctype = np.ctypeslib.as_ctypes_type(dtype)
555-
arr = np.ctypeslib.as_array((ctype * size).from_address(addr))
556-
return host_ptr, arr
557-
558-
inputs = []
559-
outputs = []
570+
inputs: list[dict] = []
571+
outputs: list[dict] = []
560572

561573
for i in range(engine.num_io_tensors):
562574
tensor_name = engine.get_tensor_name(i)
@@ -568,9 +580,16 @@ def _alloc_pinned_host(size: int, dtype: np.dtype):
568580

569581
err, device_ptr = cudart.cudaMalloc(nbytes)
570582
if err != cudart.cudaError_t.cudaSuccess:
571-
raise RuntimeError(f"cudaMalloc failed: {err}")
583+
self.logger.error(f"cudaMalloc failed: {err}")
584+
self._free_buffers(inputs + outputs)
585+
return ([], [], err)
572586

573-
host_ptr, host_mem = _alloc_pinned_host(size, dtype)
587+
host_ptr, host_mem, err = self._alloc_pinned_host(size, dtype)
588+
if err != cudart.cudaError_t.cudaSuccess:
589+
self.logger.error(f"cudaMallocHost failed: {err}")
590+
cudart.cudaFree(device_ptr)
591+
self._free_buffers(inputs + outputs)
592+
return ([], [], err)
574593

575594
if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
576595
np.copyto(host_mem, np.random.randn(size).astype(dtype))
@@ -596,11 +615,7 @@ def _alloc_pinned_host(size: int, dtype: np.dtype):
596615

597616
context.set_tensor_address(tensor_name, int(device_ptr))
598617

599-
err, stream_handle = cudart.cudaStreamCreate()
600-
if err != cudart.cudaError_t.cudaSuccess:
601-
raise RuntimeError(f"cudaStreamCreate failed: {err}")
602-
603-
return (inputs, outputs, stream_handle)
618+
return (inputs, outputs, cudart.cudaError_t.cudaSuccess)
604619

605620
def _setup_execution_context(
606621
self, serialized_engine: bytes
@@ -691,7 +706,8 @@ def run(
691706
flush_timing_cache: If True, save the timing cache to disk after engine build.
692707
693708
Returns:
694-
Measured median latency in milliseconds
709+
Measured median latency in milliseconds, or float("inf") on any error
710+
(e.g. build failure, deserialization failure, buffer/stream allocation failure).
695711
"""
696712
serialized_engine = engine = context = stream_handle = None
697713
inputs, outputs = [], []
@@ -705,7 +721,17 @@ def run(
705721
if engine is None or context is None:
706722
return float("inf")
707723

708-
inputs, outputs, stream_handle = self._allocate_buffers(engine, context)
724+
inputs, outputs, alloc_err = self._allocate_buffers(engine, context)
725+
if alloc_err != cudart.cudaError_t.cudaSuccess:
726+
self.logger.error(f"Buffer allocation failed: {alloc_err}")
727+
return float("inf")
728+
729+
err, sh = cudart.cudaStreamCreate()
730+
if err != cudart.cudaError_t.cudaSuccess:
731+
self.logger.error(f"cudaStreamCreate failed: {err}")
732+
return float("inf")
733+
stream_handle = sh
734+
709735
self._run_warmup(context, inputs, outputs, stream_handle)
710736
latencies = self._run_timing(context, inputs, outputs, stream_handle)
711737

@@ -750,11 +776,7 @@ def run(
750776
return float("inf")
751777
finally:
752778
try:
753-
for buf in inputs + outputs:
754-
if "host_ptr" in buf:
755-
cudart.cudaFreeHost(buf["host_ptr"])
756-
if "device_ptr" in buf:
757-
cudart.cudaFree(buf["device_ptr"])
779+
self._free_buffers(inputs + outputs)
758780
if stream_handle is not None:
759781
cudart.cudaStreamDestroy(stream_handle)
760782
del (

0 commit comments

Comments (0)