Skip to content

Commit 54cefb2

Browse files
ajrasane and claude committed
Add --trt_build flag to torch_quant_to_onnx and simplify tests
Move TRT engine build logic into the script as a --trt_build flag, removing the duplicate trtexec invocation from the test file. Signed-off-by: ajrasane <arasane@nvidia.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 9810e1d commit 54cefb2

File tree

2 files changed

+27
-34
lines changed

2 files changed

+27
-34
lines changed

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import json
1919
import re
20+
import subprocess
2021
import sys
2122
import warnings
2223
from pathlib import Path
@@ -395,6 +396,11 @@ def main():
395396
default=128,
396397
help="Number of scoring steps for auto quantization. Default is 128.",
397398
)
399+
parser.add_argument(
400+
"--trt_build",
401+
action="store_true",
402+
help="Build a TensorRT engine from the exported ONNX model using trtexec.",
403+
)
398404
parser.add_argument(
399405
"--no_pretrained",
400406
action="store_true",
@@ -496,6 +502,26 @@ def main():
496502

497503
print(f"Quantized ONNX model is saved to {args.onnx_save_path}")
498504

505+
if args.trt_build:
506+
build_trt_engine(args.onnx_save_path)
507+
508+
509+
def build_trt_engine(onnx_path, opt_level=4, timeout=600):
    """Build a TensorRT engine from the exported ONNX model using trtexec.

    Args:
        onnx_path: Path to the ONNX model to compile.
        opt_level: Value passed to trtexec's --builderOptimizationLevel.
            Defaults to 4, matching the previous hard-coded behavior.
        timeout: Maximum number of seconds to wait for the build. Defaults
            to 600, matching the previous hard-coded behavior.

    Raises:
        RuntimeError: If trtexec is not found on PATH, the build times out,
            or trtexec exits with a nonzero return code.
    """
    # List form (shell=False) so onnx_path is never shell-interpreted.
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        f"--builderOptimizationLevel={opt_level}",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError as err:
        # Normalize to the function's error contract instead of leaking
        # a raw FileNotFoundError when TensorRT is not installed.
        raise RuntimeError(
            "trtexec not found on PATH; install TensorRT to use --trt_build."
        ) from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(
            f"TensorRT engine build timed out after {timeout}s for {onnx_path}."
        ) from err
    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
    print("TensorRT engine build succeeded.")
524+
499525

500526
if __name__ == "__main__":
501527
main()

tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
# limitations under the License.
1515

1616

17-
import os
18-
import subprocess
19-
2017
import pytest
2118
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
2219

@@ -31,40 +28,13 @@
3128
"resnet50": ("resnet50", None),
3229
}
3330

34-
# Builder optimization level: 4 for low-bit modes, 3 otherwise
35-
_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}
36-
37-
38-
def _verify_trt_engine_build(onnx_save_path, quantize_mode):
39-
"""Verify the exported ONNX model can be compiled into a TensorRT engine."""
40-
example_dir = os.path.join(
41-
os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
42-
)
43-
onnx_path = os.path.join(example_dir, onnx_save_path)
44-
assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"
45-
46-
opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
47-
cmd = [
48-
"trtexec",
49-
f"--onnx={onnx_path}",
50-
"--stronglyTyped",
51-
f"--builderOptimizationLevel={opt_level}",
52-
]
53-
54-
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
55-
assert result.returncode == 0, (
56-
f"TensorRT engine build failed for {onnx_save_path} "
57-
f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
58-
)
59-
6031

6132
@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
6233
@pytest.mark.parametrize("model_key", list(_MODELS))
6334
def test_torch_onnx(model_key, quantize_mode):
6435
timm_model_name, model_kwargs = _MODELS[model_key]
6536
onnx_save_path = f"{model_key}.{quantize_mode}.onnx"
6637

67-
# Step 1: Quantize and export to ONNX
6838
cmd_parts = extend_cmd_parts(
6939
["python", "torch_quant_to_onnx.py"],
7040
timm_model_name=timm_model_name,
@@ -74,8 +44,5 @@ def test_torch_onnx(model_key, quantize_mode):
7444
calibration_data_size="1",
7545
num_score_steps="1",
7646
)
77-
cmd_parts.append("--no_pretrained")
47+
cmd_parts.extend(["--no_pretrained", "--trt_build"])
7848
run_example_command(cmd_parts, "torch_onnx")
79-
80-
# Step 2: Verify the exported ONNX model builds a TensorRT engine
81-
_verify_trt_engine_build(onnx_save_path, quantize_mode)

0 commit comments

Comments
 (0)