-
Notifications
You must be signed in to change notification settings - Fork 382
Expand file tree
/
Copy pathbenchmark.py
More file actions
759 lines (647 loc) · 31.7 KB
/
benchmark.py
File metadata and controls
759 lines (647 loc) · 31.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT Utilities and Benchmark Module.
This benchmark module is used to evaluate ONNX model performance.
It provides comprehensive TensorRT utilities including:
- Benchmark framework for measuring TensorRT engine performance
- Graph utilities for tensor analysis
**Benchmark Classes:**
- Benchmark: Abstract base class defining the benchmarking interface
- TrtExecBenchmark: Uses trtexec command-line tool for benchmarking
- TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling
"""
import ctypes
import importlib.util
import os
import re
import shutil
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import numpy as np
import torch
from modelopt.onnx.logging_config import logger
from modelopt.onnx.quantization.ort_utils import _check_for_trtexec, _run_trtexec
TRT_AVAILABLE = importlib.util.find_spec("tensorrt") is not None
if TRT_AVAILABLE:
import tensorrt as trt
TORCH_CUDA_AVAILABLE = torch.cuda.is_available()
def _validate_shape_range(min_shape: list, opt_shape: list, max_shape: list) -> None:
"""Raise ValueError if shape lengths differ or if min <= opt <= max fails at any dimension."""
if len(min_shape) != len(opt_shape) or len(opt_shape) != len(max_shape):
raise ValueError("min_shape, opt_shape, and max_shape must have the same length")
for i, (min_d, opt_d, max_d) in enumerate(zip(min_shape, opt_shape, max_shape)):
if min_d > opt_d or opt_d > max_d:
raise ValueError(
f"Invalid shape range at dimension {i}: "
f"min={min_d}, opt={opt_d}, max={max_d}. "
f"Must satisfy min <= opt <= max"
)
class Benchmark(ABC):
"""Abstract base class for TensorRT model benchmarking.
This class defines the interface that all benchmark implementations must follow.
It provides a consistent API for measuring inference latency of ONNX models
when converted to TensorRT engines.
Attributes:
timing_cache_file: Path to the TensorRT timing cache file.
warmup_runs: Number of warmup iterations before timing.
timing_runs: Number of iterations for latency measurement.
plugin_libraries: List of paths to plugin libraries.
logger: Logger instance for this benchmark.
Subclasses must implement:
run(): Execute the benchmark and return latency in milliseconds.
"""
def __init__(
self,
timing_cache_file: str | None = None,
warmup_runs: int = 5,
timing_runs: int = 10,
plugin_libraries: list[str] | None = None,
):
"""Initialize the benchmark.
Args:
timing_cache_file: Path to timing cache file to accelerate engine builds.
If None, uses '/tmp/trtexec_timing.cache' as default.
warmup_runs: Number of warmup iterations before timing measurements.
timing_runs: Number of iterations for latency measurement. Results
are averaged across these runs.
plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files).
These plugins will be loaded during engine building.
If None, no custom plugins are loaded.
"""
self.timing_cache_file = timing_cache_file or "/tmp/trtexec_timing.cache" # nosec B108
self.warmup_runs = warmup_runs
self.timing_runs = timing_runs
self.plugin_libraries = plugin_libraries or []
self.logger = logger
@abstractmethod
def run(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float:
"""Run benchmark on the given ONNX model.
Args:
path_or_bytes: Path to the ONNX model (str) or raw model data (bytes)
log_file: Optional path to save benchmark logs
Returns:
Measured latency in milliseconds, or float("inf") on failure
"""
raise NotImplementedError("Subclasses must implement this method")
def __call__(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float:
"""Convenience method to call benchmark as a function.
Args:
path_or_bytes: Path to the ONNX model (str) or raw model data (bytes)
log_file: Optional path to save benchmark logs
Returns:
Measured latency in milliseconds, or float("inf") on failure.
"""
return self.run(path_or_bytes, log_file)
def _write_log_file(self, file: Path | str | None, content: str) -> None:
if file is None:
return
if isinstance(file, str):
file = Path(file)
try:
file.parent.mkdir(parents=True, exist_ok=True)
file.write_text(content)
self.logger.debug(f"Saved logs to: {file}")
except Exception as e:
self.logger.warning(f"Failed to save logs to {file}: {e}")
class TrtExecBenchmark(Benchmark):
"""TensorRT benchmark using trtexec command-line tool.
This implementation uses the trtexec binary to build engines and measure
inference latency. It is the most straightforward method and closely
mirrors standard TensorRT workflows.
"""
def __init__(
self,
timing_cache_file: str | None = None,
warmup_runs: int = 5,
timing_runs: int = 10,
plugin_libraries: list[str] | None = None,
trtexec_path: str = "trtexec",
trtexec_args: list[str] | None = None,
):
"""Initialize the trtexec benchmark.
Args:
timing_cache_file: See :meth:`Benchmark.__init__`.
warmup_runs: See :meth:`Benchmark.__init__`.
timing_runs: See :meth:`Benchmark.__init__`.
plugin_libraries: See :meth:`Benchmark.__init__`.
trtexec_path: Path to trtexec binary. Defaults to 'trtexec' which
looks for the binary in PATH.
trtexec_args: Additional command-line arguments to pass to trtexec.
These are appended after the standard arguments.
Example: ['--fp16', '--workspace=4096', '--verbose']
"""
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)
self.trtexec_path = trtexec_path
self.trtexec_args = trtexec_args if trtexec_args is not None else []
self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_")
self.engine_path = os.path.join(self.temp_dir, "engine.trt")
self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
self.logger.debug(f"Temporary model path: {self.temp_model_path}")
self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms"
self._base_cmd = [
self.trtexec_path,
f"--avgRuns={self.timing_runs}",
f"--iterations={self.timing_runs}",
f"--warmUp={self.warmup_runs}",
"--stronglyTyped",
f"--saveEngine={self.engine_path}",
f"--timingCacheFile={self.timing_cache_file}",
]
for plugin_lib in self.plugin_libraries:
plugin_path = Path(plugin_lib).resolve()
if not plugin_path.exists():
self.logger.warning(f"Plugin library not found: {plugin_path}")
continue
self._base_cmd.append(f"--staticPlugins={plugin_path}")
self.logger.debug(f"Added plugin library: {plugin_path}")
trtexec_args = self.trtexec_args or []
has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
if has_remote_config:
try:
_check_for_trtexec(min_version="10.15")
self.logger.debug("TensorRT Python API version >= 10.15 detected")
if "--safe" not in trtexec_args:
self.logger.warning(
"Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments."
)
self.trtexec_args.append("--safe")
if "--skipInference" not in trtexec_args:
self.logger.warning(
"Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments."
)
self.trtexec_args.append("--skipInference")
return
except ImportError:
self.logger.warning(
"Remote autotuning is not supported with TensorRT version < 10.15. "
"Removing --remoteAutoTuningConfig from trtexec arguments"
)
trtexec_args = [
arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg
]
self._base_cmd.extend(trtexec_args)
self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}")
def __del__(self):
"""Cleanup temporary directory."""
if hasattr(self, "temp_dir"):
try:
shutil.rmtree(self.temp_dir, ignore_errors=True)
self.logger.debug(f"Cleaned up temporary directory: {self.temp_dir}")
except Exception as e:
self.logger.warning(f"Failed to cleanup temporary directory: {e}")
def run(
self,
path_or_bytes: str | bytes,
log_file: str | None = None,
flush_timing_cache: bool = False,
) -> float:
"""Run benchmark using trtexec.
Args:
path_or_bytes: Path to the ONNX model (str) or raw model data (bytes)
log_file: Optional path to save trtexec logs
Returns:
Measured median latency in milliseconds
"""
if not os.path.exists(self.timing_cache_file):
self.logger.debug(f"Will create timing cache: {self.timing_cache_file}")
try:
model_path = path_or_bytes
if isinstance(model_path, bytes):
with open(self.temp_model_path, "wb") as f:
f.write(model_path)
model_path = self.temp_model_path
self.logger.debug(f"Wrote model bytes to temporary file: {model_path}")
cmd = [*self._base_cmd, f"--onnx={model_path}"]
self.logger.debug(f"Running: {' '.join(cmd)}")
result = _run_trtexec(cmd)
self._write_log_file(
log_file,
"\n".join(
[
f"Command: {' '.join(cmd)}",
f"Return code: {result.returncode}",
"=" * 80,
"STDOUT:",
"=" * 80,
result.stdout,
"\n" + "=" * 80,
"STDERR:",
"=" * 80,
result.stderr,
"\n" + "=" * 80,
]
),
)
if result.returncode != 0:
self.logger.error(f"trtexec failed with return code {result.returncode}")
self.logger.error(f"stderr: {result.stderr}")
return float("inf")
if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)):
self.logger.warning("Could not parse median latency from trtexec output")
self.logger.debug(f"trtexec stdout:\n{result.stdout}")
return float("inf")
latency = float(match.group(1))
self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms")
return latency
except FileNotFoundError:
self.logger.error(f"trtexec binary not found: {self.trtexec_path}")
self.logger.error("Please ensure TensorRT is installed and trtexec path is correct")
return float("inf")
except Exception as e:
self.logger.error(f"Benchmark failed: {e}")
return float("inf")
class TensorRTPyBenchmark(Benchmark):
"""TensorRT benchmark using Python API with plugin support.
This implementation directly uses the TensorRT Python API to build engines
and measure inference latency. It provides more control than trtexec and
can be faster for certain workflows as it avoids subprocess overhead.
"""
def __init__(
self,
timing_cache_file: str | None = None,
warmup_runs: int = 5,
timing_runs: int = 20,
plugin_libraries: list[str] | None = None,
):
"""Initialize the TensorRT Python API benchmark.
Creates persistent TensorRT objects (Logger, Builder, Runtime) and
loads the timing cache from disk if available. Optionally loads custom
TensorRT plugin libraries for models with custom operations.
Args:
timing_cache_file: Path to TensorRT timing cache file. If None,
defaults to '/tmp/trtexec_timing.cache'.
warmup_runs: Number of warmup iterations before timing measurements.
timing_runs: Number of iterations for latency measurement.
plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files).
These plugins will be loaded and registered for use during
engine building. If None, no custom plugins are loaded.
Raises:
ImportError: If tensorrt is not installed or if torch is not built with CUDA support.
FileNotFoundError: If a specified plugin library file does not exist.
RuntimeError: If plugin library loading fails.
"""
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)
if not TRT_AVAILABLE:
raise ImportError("TensorRT Python API not available. Please install tensorrt package.")
if not TORCH_CUDA_AVAILABLE:
raise ImportError(
"PyTorch with CUDA support not available. Please install torch with CUDA: pip install torch"
)
self.trt_logger = trt.Logger(trt.Logger.WARNING)
self.builder = trt.Builder(self.trt_logger)
self.runtime = trt.Runtime(self.trt_logger)
self._loaded_plugin_handles = []
if self.plugin_libraries:
self._load_plugin_libraries()
trt.init_libnvinfer_plugins(self.trt_logger, "")
self._plugin_registry = trt.get_plugin_registry()
self.network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
self.network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
self._timing_cache = None
self._load_timing_cache()
self._shape_configs = {}
def _load_plugin_libraries(self):
"""Load custom TensorRT plugin libraries from shared object files.
This method loads plugin libraries using ctypes and initializes them
with the TensorRT plugin registry. Plugins must export the
initLibNvInferPlugins function to register their implementations.
The loaded library handles are stored to prevent them from being
garbage collected during the benchmark's lifetime.
Raises:
FileNotFoundError: If a plugin library file does not exist.
RuntimeError: If plugin initialization fails.
"""
for plugin_lib in self.plugin_libraries:
plugin_path = Path(plugin_lib).resolve()
if not plugin_path.exists():
raise FileNotFoundError(f"Plugin library not found: {plugin_path}")
self.logger.info(f"Loading TensorRT plugin: {plugin_path}")
try:
if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"):
plugin_handle = ctypes.CDLL(
str(plugin_path), mode=os.RTLD_LAZY | os.RTLD_GLOBAL
)
else:
# Fallback for platforms without RTLD flags (e.g., Windows)
plugin_handle = ctypes.CDLL(str(plugin_path))
# Store handle to prevent garbage collection
self._loaded_plugin_handles.append(plugin_handle)
# Try to initialize plugin with TensorRT registry
# Most TensorRT plugins export initLibNvInferPlugins function
if hasattr(plugin_handle, "initLibNvInferPlugins"):
init_func = plugin_handle.initLibNvInferPlugins
# Function signature: bool initLibNvInferPlugins(void* logger, const char* namespace)
init_func.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
init_func.restype = ctypes.c_bool
# Initialize with the TensorRT logger and default namespace
success = init_func(None, b"")
if not success:
self.logger.warning(
f"Plugin initialization returned false for: {plugin_path}"
)
else:
self.logger.info(f"Successfully initialized plugin: {plugin_path.name}")
else:
self.logger.info(
f"Plugin loaded (no initLibNvInferPlugins function): {plugin_path.name}"
)
except Exception as e:
raise RuntimeError(f"Failed to load plugin library {plugin_path}: {e}") from e
def set_shapes(self, input_name: str, min_shape: list, opt_shape: list, max_shape: list):
"""Set custom min/opt/max shapes for a dynamic input.
This method allows you to specify custom shape ranges for dynamic inputs
(inputs with -1 dimensions). If not specified, the benchmark will use
default shapes (all -1 dimensions become 1).
Args:
input_name: Name of the input tensor to configure.
min_shape: Minimum shape for this input. List of integers.
opt_shape: Optimal/default shape for this input. List of integers.
max_shape: Maximum shape for this input. List of integers.
"""
_validate_shape_range(min_shape, opt_shape, max_shape)
self._shape_configs[input_name] = (min_shape, opt_shape, max_shape)
self.logger.debug(
f"Set shapes for input '{input_name}': "
f"min={min_shape}, opt={opt_shape}, max={max_shape}"
)
def _build_engine(
self,
path_or_bytes: str | bytes,
flush_timing_cache: bool = False,
) -> tuple[bytes | None, float | None]:
"""Build a serialized TensorRT engine from ONNX model data.
Args:
path_or_bytes: Path to the ONNX model (str) or raw model data (bytes).
flush_timing_cache: If True, save the timing cache to disk after build.
Returns:
(serialized_engine, build_time) on success, or (None, None) on parse/build failure.
"""
config = self.builder.create_builder_config()
network = self.builder.create_network(self.network_flags)
parser = trt.OnnxParser(network, self.trt_logger)
try:
config.set_flag(trt.BuilderFlag.DIRECT_IO)
if not config.set_timing_cache(self._timing_cache, ignore_mismatch=True):
self.logger.warning("Failed to set timing cache to builder config")
if isinstance(path_or_bytes, bytes):
self.logger.debug(f"Parsing ONNX model from bytes (size: {len(path_or_bytes)})")
model_data = path_or_bytes
else:
self.logger.debug(f"Parsing ONNX model: {path_or_bytes}")
with open(path_or_bytes, "rb") as f:
model_data = f.read()
if not parser.parse(model_data):
self.logger.error("Failed to parse ONNX model")
for error_idx in range(parser.num_errors):
self.logger.error(f" {parser.get_error(error_idx)}")
return (None, None)
has_dynamic_shapes = any(
any(dim == -1 for dim in network.get_input(i).shape)
for i in range(network.num_inputs)
)
if has_dynamic_shapes:
profile = self.builder.create_optimization_profile()
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_name = input_tensor.name
shape = list(input_tensor.shape)
if input_name in self._shape_configs:
min_shape, opt_shape, max_shape = self._shape_configs[input_name]
self.logger.debug(
f"Using custom shapes for input '{input_name}': "
f"min={min_shape}, opt={opt_shape}, max={max_shape}"
)
else:
min_shape = [1 if dim == -1 else dim for dim in shape]
opt_shape = [1 if dim == -1 else dim for dim in shape]
max_shape = [1 if dim == -1 else dim for dim in shape]
self.logger.debug(
f"Using default shapes for input '{input_name}': {opt_shape}"
)
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
self.logger.debug("Building TensorRT engine...")
build_start = time.perf_counter()
serialized_engine = self.builder.build_serialized_network(network, config)
build_time = time.perf_counter() - build_start
if serialized_engine is None:
self.logger.error("Failed to build TensorRT engine")
return (None, None)
self.logger.debug(f"Engine built successfully in {build_time:.2f}s")
if flush_timing_cache:
self._save_timing_cache()
return (serialized_engine, build_time)
finally:
del parser, network, config
@staticmethod
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray]:
"""Allocate pinned host memory using PyTorch and return (tensor, numpy_view).
Returns:
(host_tensor, arr): Pinned PyTorch tensor and a numpy view over it.
"""
torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
host_tensor = torch.empty(int(size), dtype=torch_dtype).pin_memory()
return host_tensor, host_tensor.numpy()
@staticmethod
def _free_buffers(bufs: list[dict]) -> None:
"""Release buffer references; PyTorch handles underlying memory deallocation."""
bufs.clear()
def _allocate_buffers(
self,
engine: "trt.ICudaEngine",
context: "trt.IExecutionContext",
) -> tuple[list[dict], list[dict]]:
"""Allocate pinned host and device tensors for engine I/O and set tensor addresses.
Args:
engine: Deserialized TensorRT engine.
context: Execution context with tensor shapes set.
Returns:
(inputs, outputs): Lists of buffer dicts containing PyTorch tensors.
"""
inputs: list[dict] = []
outputs: list[dict] = []
for i in range(engine.num_io_tensors):
tensor_name = engine.get_tensor_name(i)
np_dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
shape = context.get_tensor_shape(tensor_name)
size = int(trt.volume(shape))
host_tensor, host_mem = self._alloc_pinned_host(size, np_dtype)
torch_dtype = torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype
device_tensor = torch.empty(size, dtype=torch_dtype, device="cuda")
context.set_tensor_address(tensor_name, device_tensor.data_ptr())
if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
np.copyto(host_mem, np.random.randn(size).astype(np_dtype))
inputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})
else:
outputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})
return (inputs, outputs)
def _setup_execution_context(
self, serialized_engine: bytes
) -> tuple["trt.ICudaEngine | None", "trt.IExecutionContext | None"]:
"""Deserialize the engine and create an execution context.
Args:
serialized_engine: Serialized TensorRT engine bytes.
Returns:
(engine, context) or (None, None) if deserialization fails.
"""
engine = self.runtime.deserialize_cuda_engine(serialized_engine)
if engine is None:
self.logger.error("Failed to deserialize engine")
return (None, None)
context = engine.create_execution_context()
return (engine, context)
def _run_warmup(
self,
context: "trt.IExecutionContext",
inputs: list[dict],
outputs: list[dict],
stream: "torch.cuda.Stream",
) -> None:
"""Run warmup iterations to stabilize GPU state and cache."""
self.logger.debug(f"Running {self.warmup_runs} warmup iterations...")
with torch.cuda.stream(stream):
for _ in range(self.warmup_runs):
for inp in inputs:
inp["device"].copy_(inp["host"], non_blocking=True)
context.execute_async_v3(stream.cuda_stream)
for out in outputs:
out["host"].copy_(out["device"], non_blocking=True)
stream.synchronize()
def _run_timing(
self,
context: "trt.IExecutionContext",
inputs: list[dict],
outputs: list[dict],
stream: "torch.cuda.Stream",
) -> np.ndarray:
"""Run timing iterations and return per-run latencies in milliseconds."""
self.logger.debug(f"Running {self.timing_runs} timing iterations...")
latencies = []
with torch.cuda.stream(stream):
for _ in range(self.timing_runs):
for inp in inputs:
inp["device"].copy_(inp["host"], non_blocking=True)
stream.synchronize()
start = time.perf_counter()
context.execute_async_v3(stream.cuda_stream)
stream.synchronize()
end = time.perf_counter()
latencies.append((end - start) * 1000.0)
for out in outputs:
out["host"].copy_(out["device"], non_blocking=True)
return np.array(latencies)
def run(
self,
path_or_bytes: str | bytes,
log_file: str | None = None,
flush_timing_cache: bool = False,
) -> float:
"""Run benchmark using TensorRT Python API.
Args:
path_or_bytes: Path to the ONNX model (str) or raw model data (bytes)
log_file: Optional path to save benchmark logs
flush_timing_cache: If True, save the timing cache to disk after engine build.
Returns:
Measured median latency in milliseconds, or float("inf") on any error
(e.g. build failure, deserialization failure, buffer/stream allocation failure).
"""
serialized_engine = engine = context = stream = None
inputs, outputs = [], []
try:
serialized_engine, build_time = self._build_engine(path_or_bytes, flush_timing_cache)
if serialized_engine is None or build_time is None:
return float("inf")
engine, context = self._setup_execution_context(serialized_engine)
if engine is None or context is None:
return float("inf")
inputs, outputs = self._allocate_buffers(engine, context)
stream = torch.cuda.Stream()
self._run_warmup(context, inputs, outputs, stream)
latencies = self._run_timing(context, inputs, outputs, stream)
median_latency = float(np.median(latencies))
mean_latency = float(np.mean(latencies))
std_latency = float(np.std(latencies))
min_latency = float(np.min(latencies))
max_latency = float(np.max(latencies))
self.logger.info(
f"TensorRT Python API benchmark: min={min_latency:.3f}ms, max={max_latency:.3f}ms, "
f"mean={mean_latency:.3f}ms, std={std_latency:.3f}ms, median={median_latency:.3f}ms"
)
model_info = (
f"<bytes, size={len(path_or_bytes)}>"
if isinstance(path_or_bytes, bytes)
else path_or_bytes
)
self._write_log_file(
log_file,
"\n".join(
[
"TensorRT Python API Benchmark",
f"Model: {model_info}",
f"Build time: {build_time:.2f}s",
f"Warmup runs: {self.warmup_runs}",
f"Timing runs: {self.timing_runs}",
"Latency Statistics:",
f" Min: {min_latency:.3f} ms",
f" Max: {max_latency:.3f} ms",
f" Mean: {mean_latency:.3f} ms",
f" Std: {std_latency:.3f} ms",
f" Median: {median_latency:.3f} ms",
f"All latencies: {latencies.tolist()}",
]
),
)
return median_latency
except Exception as e:
self.logger.error(f"Benchmark failed: {e}", exc_info=True)
return float("inf")
finally:
try:
self._free_buffers(inputs)
self._free_buffers(outputs)
del inputs, outputs, stream, context, engine, serialized_engine
except Exception as cleanup_error:
self.logger.warning(f"Error during cleanup: {cleanup_error}")
def _load_timing_cache(self):
"""Load timing cache from file or create a new one."""
config = self.builder.create_builder_config()
if os.path.exists(self.timing_cache_file):
try:
with open(self.timing_cache_file, "rb") as f:
timing_cache_data = f.read()
self._timing_cache = config.create_timing_cache(timing_cache_data)
self.logger.debug(f"Loaded timing cache from: {self.timing_cache_file}")
except Exception as e:
self.logger.warning(f"Failed to load timing cache: {e}")
self.logger.debug("Creating new timing cache")
self._timing_cache = None
if self._timing_cache is None:
self._timing_cache = config.create_timing_cache(b"")
self.logger.debug("Created new timing cache")
del config
def _save_timing_cache(self):
"""Save timing cache to file."""
try:
if self._timing_cache is not None:
timing_cache_data = self._timing_cache.serialize()
with open(self.timing_cache_file, "wb") as f:
f.write(timing_cache_data)
self.logger.debug(f"Saved timing cache to: {self.timing_cache_file}")
except Exception as e:
self.logger.warning(f"Failed to save timing cache: {e}")