Skip to content

Commit f405dbb

Browse files
committed
simplify benchmark code
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 88f8d05 commit f405dbb

1 file changed

Lines changed: 15 additions & 36 deletions

File tree

modelopt/onnx/quantization/autotune/benchmark.py

Lines changed: 15 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python3
21
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
32
# SPDX-License-Identifier: Apache-2.0
43
#
@@ -38,7 +37,6 @@
3837

3938
import numpy as np
4039

41-
# Optional dependencies - gracefully handle missing packages
4240
try:
4341
import tensorrt as trt
4442

@@ -165,11 +163,10 @@ def __init__(
165163
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)
166164
self.trtexec_path = trtexec_path
167165
self.trtexec_args = trtexec_args if trtexec_args is not None else []
168-
self._temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_")
169-
self.engine_dir = self._temp_dir
170-
self.engine_path = os.path.join(self.engine_dir, "engine.trt")
171-
self.temp_model_path = os.path.join(self.engine_dir, "temp_model.onnx")
172-
self.logger.debug(f"Created temporary engine directory: {self.engine_dir}")
166+
self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_")
167+
self.engine_path = os.path.join(self.temp_dir, "engine.trt")
168+
self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
169+
self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
173170
self.logger.debug(f"Temporary model path: {self.temp_model_path}")
174171
self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms"
175172

@@ -213,10 +210,10 @@ def __init__(
213210

214211
def __del__(self):
215212
"""Cleanup temporary directory."""
216-
if hasattr(self, "_temp_dir"):
213+
if hasattr(self, "temp_dir"):
217214
try:
218-
shutil.rmtree(self._temp_dir, ignore_errors=True)
219-
self.logger.debug(f"Cleaned up temporary directory: {self._temp_dir}")
215+
shutil.rmtree(self.temp_dir, ignore_errors=True)
216+
self.logger.debug(f"Cleaned up temporary directory: {self.temp_dir}")
220217
except Exception as e:
221218
self.logger.warning(f"Failed to cleanup temporary directory: {e}")
222219

@@ -344,13 +341,8 @@ def __init__(
344341

345342
self.network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
346343
self.network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
347-
348-
# Load timing cache from disk or create new one
349344
self._timing_cache = None
350345
self._load_timing_cache()
351-
352-
# Storage for user-defined shape configurations
353-
# Format: {input_name: (min_shape, opt_shape, max_shape)}
354346
self._shape_configs = {}
355347

356348
def _load_plugin_libraries(self):
@@ -600,9 +592,8 @@ def run(
600592
min_latency = float(np.min(latencies))
601593
max_latency = float(np.max(latencies))
602594

603-
self.logger.info("TensorRT Python API benchmark:")
604595
self.logger.info(
605-
f" min={min_latency:.3f}ms, max={max_latency:.3f}ms, "
596+
f"TensorRT Python API benchmark: min={min_latency:.3f}ms, max={max_latency:.3f}ms, "
606597
f"mean={mean_latency:.3f}ms, std={std_latency:.3f}ms, median={median_latency:.3f}ms"
607598
)
608599

@@ -639,33 +630,21 @@ def run(
639630
return float("inf")
640631
finally:
641632
try:
633+
[inp["device"].free() for inp in inputs if "device" in inp]
634+
[out["device"].free() for out in outputs if "device" in out]
642635
for inp in inputs:
643-
if "device" in inp:
644-
inp["device"].free()
645636
if "host" in inp:
646637
del inp["host"]
647638
for out in outputs:
648-
if "device" in out:
649-
out["device"].free()
650639
if "host" in out:
651640
del out["host"]
652641
inputs.clear()
653642
outputs.clear()
654-
655-
if context is not None:
656-
del context
657-
if stream is not None:
658-
del stream
659-
if engine is not None:
660-
del engine
661-
if serialized_engine is not None:
662-
del serialized_engine
663-
if parser is not None:
664-
del parser
665-
if network is not None:
666-
del network
667-
if config is not None:
668-
del config
643+
resources = [context, stream, engine, serialized_engine, parser, network, config]
644+
for resource in resources:
645+
if resource is not None:
646+
del resource
647+
resources.clear()
669648
except Exception as cleanup_error:
670649
self.logger.warning(f"Error during cleanup: {cleanup_error}")
671650

0 commit comments

Comments
 (0)