@@ -134,7 +134,7 @@ def __call__(self, path_or_bytes: str | bytes, log_file: str | None = None) -> f
134134 log_file: Optional path to save benchmark logs
135135
136136 Returns:
137- Measured latency in milliseconds
137+ Measured latency in milliseconds, or float("inf") on failure.
138138 """
139139 return self .run (path_or_bytes , log_file )
140140
@@ -526,6 +526,32 @@ def _build_engine(
526526 finally :
527527 del parser , network , config
528528
@staticmethod
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray, Any]:
    """Allocate pinned (page-locked) host memory for `size` elements of `dtype`.

    Args:
        size: Number of elements to allocate.
        dtype: NumPy dtype of each element.

    Returns:
        (host_ptr, arr, err): On success ``err`` is ``cudaSuccess``, ``host_ptr``
        is the raw pinned allocation handle (pass it back to ``cudaFreeHost``),
        and ``arr`` is a NumPy array view over the pinned bytes. On allocation
        failure ``host_ptr`` and ``arr`` are None and ``err`` is the CUDA error
        code.

    Raises:
        NotImplementedError: Propagated from ``np.ctypeslib.as_ctypes_type`` when
            `dtype` has no ctypes equivalent (e.g. float16). The pinned buffer is
            freed before re-raising so it does not leak.
    """
    nbytes = size * np.dtype(dtype).itemsize
    err, host_ptr = cudart.cudaMallocHost(nbytes)
    if err != cudart.cudaError_t.cudaSuccess:
        return (None, None, err)
    try:
        # cudaMallocHost may return a wrapper object rather than a plain int;
        # normalize to a raw address for from_address().
        addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr
        ctype = np.ctypeslib.as_ctypes_type(dtype)
        arr = np.ctypeslib.as_array((ctype * size).from_address(addr))
    except Exception:
        # Bug fix: previously an unsupported dtype raised AFTER the pinned
        # allocation succeeded, leaking page-locked host memory. Free it
        # before letting the exception propagate.
        cudart.cudaFreeHost(host_ptr)
        raise
    return (host_ptr, arr, cudart.cudaError_t.cudaSuccess)
545+
@staticmethod
def _free_buffers(bufs: list[dict]) -> None:
    """Release the host and device allocations referenced by each buffer dict.

    Each entry may carry a pinned-host handle under ``"host_ptr"`` and a device
    allocation under ``"device_ptr"``; entries that are absent or None are
    skipped.
    """
    for entry in bufs:
        host = entry.get("host_ptr")
        if host is not None:
            cudart.cudaFreeHost(host)
        device = entry.get("device_ptr")
        if device is not None:
            cudart.cudaFree(device)
554+
529555 def _allocate_buffers (
530556 self ,
531557 engine : "trt.ICudaEngine" ,
@@ -538,25 +564,11 @@ def _allocate_buffers(
538564 context: Execution context with tensor shapes set.
539565
540566 Returns:
541- (inputs, outputs, stream_handle) where inputs/outputs are lists of buffer dicts
542- with keys host_ptr, host, device_ptr, nbytes, name; stream_handle is a CUDA stream.
543-
544- Raises:
545- RuntimeError: If CUDA allocation or stream creation fails.
567+ (inputs, outputs, cuda_error): On success cuda_error is cudaSuccess;
568+ on failure inputs/outputs are empty and cuda_error is the failing CUDA error code.
546569 """
547-
548- def _alloc_pinned_host (size : int , dtype : np .dtype ):
549- nbytes = size * np .dtype (dtype ).itemsize
550- err , host_ptr = cudart .cudaMallocHost (nbytes )
551- if err != cudart .cudaError_t .cudaSuccess :
552- raise RuntimeError (f"cudaMallocHost failed: { err } " )
553- addr = int (host_ptr ) if hasattr (host_ptr , "__int__" ) else host_ptr
554- ctype = np .ctypeslib .as_ctypes_type (dtype )
555- arr = np .ctypeslib .as_array ((ctype * size ).from_address (addr ))
556- return host_ptr , arr
557-
558- inputs = []
559- outputs = []
570+ inputs : list [dict ] = []
571+ outputs : list [dict ] = []
560572
561573 for i in range (engine .num_io_tensors ):
562574 tensor_name = engine .get_tensor_name (i )
@@ -568,9 +580,16 @@ def _alloc_pinned_host(size: int, dtype: np.dtype):
568580
569581 err , device_ptr = cudart .cudaMalloc (nbytes )
570582 if err != cudart .cudaError_t .cudaSuccess :
571- raise RuntimeError (f"cudaMalloc failed: { err } " )
583+ self .logger .error (f"cudaMalloc failed: { err } " )
584+ self ._free_buffers (inputs + outputs )
585+ return ([], [], err )
572586
573- host_ptr , host_mem = _alloc_pinned_host (size , dtype )
587+ host_ptr , host_mem , err = self ._alloc_pinned_host (size , dtype )
588+ if err != cudart .cudaError_t .cudaSuccess :
589+ self .logger .error (f"cudaMallocHost failed: { err } " )
590+ cudart .cudaFree (device_ptr )
591+ self ._free_buffers (inputs + outputs )
592+ return ([], [], err )
574593
575594 if engine .get_tensor_mode (tensor_name ) == trt .TensorIOMode .INPUT :
576595 np .copyto (host_mem , np .random .randn (size ).astype (dtype ))
@@ -596,11 +615,7 @@ def _alloc_pinned_host(size: int, dtype: np.dtype):
596615
597616 context .set_tensor_address (tensor_name , int (device_ptr ))
598617
599- err , stream_handle = cudart .cudaStreamCreate ()
600- if err != cudart .cudaError_t .cudaSuccess :
601- raise RuntimeError (f"cudaStreamCreate failed: { err } " )
602-
603- return (inputs , outputs , stream_handle )
618+ return (inputs , outputs , cudart .cudaError_t .cudaSuccess )
604619
605620 def _setup_execution_context (
606621 self , serialized_engine : bytes
@@ -691,7 +706,8 @@ def run(
691706 flush_timing_cache: If True, save the timing cache to disk after engine build.
692707
693708 Returns:
694- Measured median latency in milliseconds
709+ Measured median latency in milliseconds, or float("inf") on any error
710+ (e.g. build failure, deserialization failure, buffer/stream allocation failure).
695711 """
696712 serialized_engine = engine = context = stream_handle = None
697713 inputs , outputs = [], []
@@ -705,7 +721,17 @@ def run(
705721 if engine is None or context is None :
706722 return float ("inf" )
707723
708- inputs , outputs , stream_handle = self ._allocate_buffers (engine , context )
724+ inputs , outputs , alloc_err = self ._allocate_buffers (engine , context )
725+ if alloc_err != cudart .cudaError_t .cudaSuccess :
726+ self .logger .error (f"Buffer allocation failed: { alloc_err } " )
727+ return float ("inf" )
728+
729+ err , sh = cudart .cudaStreamCreate ()
730+ if err != cudart .cudaError_t .cudaSuccess :
731+ self .logger .error (f"cudaStreamCreate failed: { err } " )
732+ return float ("inf" )
733+ stream_handle = sh
734+
709735 self ._run_warmup (context , inputs , outputs , stream_handle )
710736 latencies = self ._run_timing (context , inputs , outputs , stream_handle )
711737
@@ -750,11 +776,7 @@ def run(
750776 return float ("inf" )
751777 finally :
752778 try :
753- for buf in inputs + outputs :
754- if "host_ptr" in buf :
755- cudart .cudaFreeHost (buf ["host_ptr" ])
756- if "device_ptr" in buf :
757- cudart .cudaFree (buf ["device_ptr" ])
779+ self ._free_buffers (inputs + outputs )
758780 if stream_handle is not None :
759781 cudart .cudaStreamDestroy (stream_handle )
760782 del (
0 commit comments