Skip to content

Commit 3112a4e

Browse files
committed
Integrate Automated QDQ placement tool - part 3.3
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent e53ca61 commit 3112a4e

3 files changed

Lines changed: 790 additions & 0 deletions

File tree

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Command-line interface for ONNX Q/DQ autotuning."""
18+
19+
import argparse
20+
import sys
21+
from pathlib import Path
22+
23+
from modelopt.onnx.logging_config import logger
24+
from modelopt.onnx.quantization.autotune.workflows import (
25+
init_benchmark_instance,
26+
region_pattern_autotuning_workflow,
27+
)
28+
29+
DEFAULT_OUTPUT_DIR = "./autotuner_output"
30+
DEFAULT_NUM_SCHEMES = 30
31+
DEFAULT_QUANT_TYPE = "int8"
32+
DEFAULT_DQ_DTYPE = "float32"
33+
DEFAULT_TIMING_CACHE = "/tmp/trtexec_timing.cache" # nosec B108
34+
DEFAULT_WARMUP_RUNS = 5
35+
DEFAULT_TIMING_RUNS = 20
36+
37+
38+
def validate_file_path(path: str | None, description: str) -> Path | None:
39+
"""Validate that a file path exists.
40+
41+
Args:
42+
path: Path string to validate (can be None)
43+
description: Description of the file for error messages
44+
45+
Returns:
46+
Path object if valid, None if path is None
47+
48+
Raises:
49+
SystemExit: If path is provided but doesn't exist
50+
"""
51+
if path is None:
52+
return None
53+
54+
path_obj = Path(path)
55+
if not path_obj.exists():
56+
logger.error(f"{description} not found: {path_obj}")
57+
sys.exit(1)
58+
59+
return path_obj
60+
61+
62+
def log_benchmark_config(args):
63+
"""Log TensorRT benchmark configuration for transparency.
64+
65+
Logs timing cache path, warmup/timing run counts, and any custom
66+
plugin libraries that will be loaded.
67+
68+
Args:
69+
args: Parsed command-line arguments with benchmark configuration
70+
"""
71+
logger.info("Initializing TensorRT benchmark")
72+
logger.info(f" Timing cache: {args.timing_cache}")
73+
logger.info(f" Warmup runs: {args.warmup_runs}")
74+
logger.info(f" Timing runs: {args.timing_runs}")
75+
if args.plugin_libraries:
76+
logger.info(f" Plugin libraries: {', '.join(args.plugin_libraries)}")
77+
if hasattr(args, "trtexec_benchmark_args") and args.trtexec_benchmark_args:
78+
logger.info(f" Trtexec args: {args.trtexec_benchmark_args}")
79+
80+
81+
def run_autotune(args=None) -> int:
82+
"""Execute the complete pattern-based Q/DQ autotuning workflow.
83+
84+
This function orchestrates the entire optimization process:
85+
1. Parses command-line arguments (if not provided)
86+
2. Validates input paths (model, baseline, output directory)
87+
3. Initializes TensorRT benchmark instance
88+
4. Runs pattern-based region autotuning workflow
89+
5. Handles interruptions gracefully with state preservation
90+
91+
Args:
92+
args: Optional parsed command-line arguments. If None, parses sys.argv.
93+
94+
Returns:
95+
Exit code:
96+
- 0: Success
97+
- 1: Autotuning failed (exception occurred)
98+
- 130: Interrupted by user (Ctrl+C)
99+
"""
100+
if args is None:
101+
args = _get_autotune_parser().parse_args()
102+
103+
model_path = validate_file_path(args.onnx_path, "Model file")
104+
validate_file_path(args.qdq_baseline, "QDQ baseline model")
105+
output_dir = Path(args.output)
106+
107+
log_benchmark_config(args)
108+
trtexec_args = getattr(args, "trtexec_benchmark_args", None)
109+
benchmark_instance = init_benchmark_instance(
110+
use_trtexec=args.use_trtexec,
111+
plugin_libraries=args.plugin_libraries,
112+
timing_cache_file=args.timing_cache,
113+
warmup_runs=args.warmup_runs,
114+
timing_runs=args.timing_runs,
115+
trtexec_args=trtexec_args,
116+
)
117+
118+
if benchmark_instance is None:
119+
logger.error("Failed to initialize TensorRT benchmark")
120+
return 1
121+
122+
logger.info("Autotuning Mode: Pattern-Based")
123+
124+
try:
125+
node_filter_list = None
126+
if args.node_filter_list:
127+
filter_file = validate_file_path(args.node_filter_list, "Node filter list file")
128+
if filter_file:
129+
with open(filter_file) as f:
130+
node_filter_list = [
131+
line.strip()
132+
for line in f
133+
if line.strip() and not line.strip().startswith("#")
134+
]
135+
logger.info(f"Loaded {len(node_filter_list)} filter patterns from {filter_file}")
136+
137+
region_pattern_autotuning_workflow(
138+
model_path=str(model_path),
139+
output_dir=output_dir,
140+
num_schemes_per_region=args.num_schemes,
141+
pattern_cache_file=args.pattern_cache_file,
142+
state_file=args.state_file,
143+
quant_type=args.quant_type,
144+
default_dq_dtype=args.default_dq_dtype,
145+
qdq_baseline_model=args.qdq_baseline,
146+
node_filter_list=node_filter_list,
147+
verbose=args.verbose,
148+
)
149+
150+
logger.info("\n" + "=" * 70)
151+
logger.info("✓ Autotuning completed successfully!")
152+
logger.info(f"✓ Results: {output_dir}")
153+
logger.info("=" * 70)
154+
return 0
155+
156+
except KeyboardInterrupt:
157+
logger.warning("\nInterrupted by user")
158+
state_file = args.state_file or output_dir / "autotuner_state.yaml"
159+
logger.info(f"Progress saved to: {state_file}")
160+
return 130
161+
162+
except Exception as e:
163+
logger.error(f"\nAutotuning failed: {e}", exc_info=args.verbose)
164+
return 1
165+
166+
167+
def _get_autotune_parser() -> argparse.ArgumentParser:
168+
"""Create and configure the command-line argument parser."""
169+
parser = argparse.ArgumentParser(
170+
prog="modelopt.onnx.quantization.autotune",
171+
description="ONNX Q/DQ Autotuning with TensorRT",
172+
formatter_class=argparse.RawDescriptionHelpFormatter,
173+
epilog="""
174+
Examples:
175+
# Basic usage
176+
python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx
177+
178+
# Import patterns from QDQ baseline model
179+
python -m modelopt.onnx.quantization.autotune \\
180+
--onnx_path model.onnx --qdq_baseline baseline.onnx
181+
182+
# Use pattern cache for warm-start
183+
python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --pattern_cache cache.yaml
184+
185+
# Full example with all options
186+
python -m modelopt.onnx.quantization.autotune \\
187+
--onnx_path model.onnx --schemes_per_region 50 \\
188+
--pattern_cache cache.yaml --qdq_baseline baseline.onnx \\
189+
--quant_type int8 --verbose
190+
""",
191+
)
192+
193+
# Model and Output
194+
io_group = parser.add_argument_group("Model and Output")
195+
io_group.add_argument(
196+
"--onnx_path", "-m", type=str, required=True, help="Path to ONNX model file"
197+
)
198+
io_group.add_argument(
199+
"--output",
200+
"-o",
201+
type=str,
202+
default=DEFAULT_OUTPUT_DIR,
203+
help=f"Output directory for results (default: {DEFAULT_OUTPUT_DIR})",
204+
)
205+
206+
# Autotuning Strategy
207+
strategy_group = parser.add_argument_group("Autotuning Strategy")
208+
strategy_group.add_argument(
209+
"--schemes_per_region",
210+
"-s",
211+
type=int,
212+
default=DEFAULT_NUM_SCHEMES,
213+
dest="num_schemes",
214+
help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})",
215+
)
216+
strategy_group.add_argument(
217+
"--pattern_cache",
218+
type=str,
219+
default=None,
220+
dest="pattern_cache_file",
221+
help="Path to pattern cache YAML for warm-start (optional)",
222+
)
223+
strategy_group.add_argument(
224+
"--qdq_baseline",
225+
type=str,
226+
default=None,
227+
help="Path to QDQ baseline ONNX model to import quantization patterns (optional)",
228+
)
229+
strategy_group.add_argument(
230+
"--state_file",
231+
type=str,
232+
default=None,
233+
help="State file path for resume capability (default: <output>/autotuner_state.yaml)",
234+
)
235+
strategy_group.add_argument(
236+
"--node_filter_list",
237+
type=str,
238+
default=None,
239+
help="Path to a file containing wildcard patterns to filter ONNX nodes (one pattern per line). "
240+
"Regions without any matching nodes are skipped during autotuning.",
241+
)
242+
243+
# Quantization
244+
quant_group = parser.add_argument_group("Quantization")
245+
quant_group.add_argument(
246+
"--quant_type",
247+
type=str,
248+
default=DEFAULT_QUANT_TYPE,
249+
choices=["int8", "fp8"],
250+
help=f"Quantization data type (default: {DEFAULT_QUANT_TYPE})",
251+
)
252+
quant_group.add_argument(
253+
"--default_dq_dtype",
254+
type=str,
255+
default=DEFAULT_DQ_DTYPE,
256+
choices=["float16", "float32", "bfloat16"],
257+
help="Default DQ output dtype if cannot be deduced (optional)",
258+
)
259+
260+
# TensorRT Benchmark
261+
trt_group = parser.add_argument_group("TensorRT Benchmark")
262+
trt_group.add_argument(
263+
"--use_trtexec",
264+
action="store_true",
265+
help="Use trtexec for benchmarking (default: False)",
266+
default=False,
267+
)
268+
trt_group.add_argument(
269+
"--timing_cache",
270+
type=str,
271+
default=DEFAULT_TIMING_CACHE,
272+
help=f"TensorRT timing cache file (default: {DEFAULT_TIMING_CACHE})",
273+
)
274+
trt_group.add_argument(
275+
"--warmup_runs",
276+
type=int,
277+
default=DEFAULT_WARMUP_RUNS,
278+
help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})",
279+
)
280+
trt_group.add_argument(
281+
"--timing_runs",
282+
type=int,
283+
default=DEFAULT_TIMING_RUNS,
284+
help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})",
285+
)
286+
trt_group.add_argument(
287+
"--plugin_libraries",
288+
"--plugins",
289+
type=str,
290+
nargs="+",
291+
default=None,
292+
dest="plugin_libraries",
293+
help="TensorRT plugin libraries (.so files) to load (optional, space-separated)",
294+
)
295+
trt_group.add_argument(
296+
"--trtexec_benchmark_args",
297+
type=str,
298+
default=None,
299+
help="Additional command-line arguments to pass to trtexec as a single quoted string. "
300+
"Example: --trtexec_benchmark_args '--fp16 --workspace=4096 --verbose'",
301+
)
302+
303+
# Logging
304+
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose DEBUG logging")
305+
306+
return parser
307+
308+
309+
if __name__ == "__main__":
310+
sys.exit(run_autotune())

0 commit comments

Comments
 (0)