Skip to content

Commit 94b2cb2

Browse files
authored
Integrate Automated QDQ placement tool - part 3.3 (#839)
## What does this PR do? This PR implements QDQ autotuner CLI. This is the initial version of CLI, it will be integrated to modelopt.onnx.quantization.autotune. Usage: ``` python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --schemes_per_region 50 --pattern_cache cache.yaml --qdq_baseline baseline.onnx --quant_type int8 --verbose ``` PR 3.1: #837 PR 3.2 #838 PR 3.3: #839 **Overview:** ? ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: Document will be added in part 4. - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: CHANGE log will be added in part 4. ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added a command-line interface for ONNX quantization autotuning with configurable parameters for models, output paths, quantization strategies, and TensorRT benchmarking. * Introduced an automated workflow for pattern-based region optimization with state management, baseline comparison, and benchmarking capabilities. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Will Guo <willg@nvidia.com>
1 parent fcdaf65 commit 94b2cb2

File tree

6 files changed

+869
-5
lines changed

6 files changed

+869
-5
lines changed
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Command-line interface for ONNX Q/DQ autotuning."""
17+
18+
import argparse
19+
import sys
20+
import tempfile
21+
from pathlib import Path
22+
23+
from modelopt.onnx.logging_config import logger
24+
from modelopt.onnx.quantization.autotune.workflows import (
25+
init_benchmark_instance,
26+
region_pattern_autotuning_workflow,
27+
)
28+
29+
# CLI defaults; each value is surfaced to users through the --help text
# generated by _get_autotune_parser().
DEFAULT_OUTPUT_DIR = "./autotuner_output"  # where results and state files land
DEFAULT_NUM_SCHEMES = 30  # candidate Q/DQ schemes evaluated per region
DEFAULT_QUANT_TYPE = "int8"  # quantization data type (see --quant_type choices)
DEFAULT_DQ_DTYPE = "float32"  # fallback DQ output dtype when it cannot be deduced
# Timing cache shared across trtexec runs, kept in the system temp directory.
DEFAULT_TIMING_CACHE = str(Path(tempfile.gettempdir()) / "trtexec_timing.cache")
DEFAULT_WARMUP_RUNS = 5  # benchmark warmup iterations before timing
DEFAULT_TIMING_RUNS = 20  # timed benchmark iterations
36+
37+
38+
def validate_file_path(path: str | None, description: str) -> Path | None:
39+
"""Validate that a file path exists.
40+
41+
Args:
42+
path: Path string to validate (can be None)
43+
description: Description of the file for error messages
44+
45+
Returns:
46+
Path object if valid, None if path is None
47+
48+
Raises:
49+
SystemExit: If path is provided but doesn't exist
50+
"""
51+
if path is None:
52+
return None
53+
54+
path_obj = Path(path)
55+
if not path_obj.exists():
56+
logger.error(f"{description} not found: {path_obj}")
57+
sys.exit(1)
58+
59+
return path_obj
60+
61+
62+
def log_benchmark_config(args):
    """Echo the TensorRT benchmark settings to the log for reproducibility.

    Records the timing-cache path, warmup/timing run counts, and any optional
    plugin libraries or extra trtexec arguments so a run can be reconstructed
    from its log alone.

    Args:
        args: Parsed command-line namespace carrying the benchmark options.
    """
    logger.info("Initializing TensorRT benchmark")
    for label, value in (
        ("Timing cache", args.timing_cache),
        ("Warmup runs", args.warmup_runs),
        ("Timing runs", args.timing_runs),
    ):
        logger.info(f"  {label}: {value}")
    if args.plugin_libraries:
        logger.info(f"  Plugin libraries: {', '.join(args.plugin_libraries)}")
    if getattr(args, "trtexec_benchmark_args", None):
        logger.info(f"  Trtexec args: {args.trtexec_benchmark_args}")
79+
80+
81+
def _load_node_filters(path_str):
    """Read wildcard node-filter patterns from a file, one pattern per line.

    Blank lines and lines starting with ``#`` are skipped. Returns ``None``
    when no path was supplied; exits (via validate_file_path) if the path
    does not exist.
    """
    if not path_str:
        return None
    filter_file = validate_file_path(path_str, "Node filter list file")
    if not filter_file:
        return None
    with open(filter_file) as handle:
        patterns = [
            stripped
            for stripped in (raw.strip() for raw in handle)
            if stripped and not stripped.startswith("#")
        ]
    logger.info(f"Loaded {len(patterns)} filter patterns from {filter_file}")
    return patterns


def run_autotune() -> int:
    """Run the complete pattern-based Q/DQ autotuning workflow from the CLI.

    Parses command-line arguments, then:
      1. Validates the input model, optional baseline, and output directory.
      2. Initializes the TensorRT benchmark backend.
      3. Executes the pattern-based region autotuning workflow.
      4. Maps outcomes (including Ctrl+C, with state preserved) to exit codes.

    Returns:
        Exit code: 0 on success, 1 on failure, 130 when interrupted by user.
    """
    args = _get_autotune_parser().parse_args()
    model_path = validate_file_path(args.onnx_path, "Model file")
    validate_file_path(args.qdq_baseline, "QDQ baseline model")
    output_dir = Path(args.output_dir)

    log_benchmark_config(args)
    benchmark_instance = init_benchmark_instance(
        use_trtexec=args.use_trtexec,
        plugin_libraries=args.plugin_libraries,
        timing_cache_file=args.timing_cache,
        warmup_runs=args.warmup_runs,
        timing_runs=args.timing_runs,
        trtexec_args=getattr(args, "trtexec_benchmark_args", None),
    )
    if benchmark_instance is None:
        logger.error("Failed to initialize TensorRT benchmark")
        return 1

    try:
        region_pattern_autotuning_workflow(
            model_path=str(model_path),
            output_dir=output_dir,
            num_schemes_per_region=args.num_schemes,
            pattern_cache_file=args.pattern_cache_file,
            state_file=args.state_file,
            quant_type=args.quant_type,
            default_dq_dtype=args.default_dq_dtype,
            qdq_baseline_model=args.qdq_baseline,
            node_filter_list=_load_node_filters(args.node_filter_list),
            verbose=args.verbose,
        )
    except KeyboardInterrupt:
        # Partial progress is persisted by the workflow; tell the user where.
        logger.warning("\nInterrupted by user")
        state_file = args.state_file or output_dir / "autotuner_state.yaml"
        logger.info(f"Progress saved to: {state_file}")
        return 130
    except Exception as e:
        # Full traceback only in verbose mode to keep default output readable.
        logger.error(f"\nAutotuning failed: {e}", exc_info=args.verbose)
        return 1

    logger.info("\n" + "=" * 70)
    logger.info("✓ Autotuning completed successfully!")
    logger.info(f"✓ Results: {output_dir}")
    logger.info("=" * 70)
    return 0
157+
158+
159+
def _get_autotune_parser() -> argparse.ArgumentParser:
    """Build and return the configured command-line argument parser."""
    cli = argparse.ArgumentParser(
        prog="modelopt.onnx.quantization.autotune",
        description="ONNX Q/DQ Autotuning with TensorRT",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx

  # Import patterns from QDQ baseline model
  python -m modelopt.onnx.quantization.autotune \\
      --onnx_path model.onnx --qdq_baseline baseline.onnx

  # Use pattern cache for warm-start
  python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --pattern_cache cache.yaml

  # Full example with all options
  python -m modelopt.onnx.quantization.autotune \\
      --onnx_path model.onnx --schemes_per_region 50 \\
      --pattern_cache cache.yaml --qdq_baseline baseline.onnx \\
      --quant_type int8 --verbose
""",
    )

    # --- Model and output locations -------------------------------------
    model_io = cli.add_argument_group("Model and Output")
    model_io.add_argument(
        "--onnx_path", "-m", type=str, required=True, help="Path to ONNX model file"
    )
    model_io.add_argument(
        "--output_dir",
        "-o",
        type=str,
        default=DEFAULT_OUTPUT_DIR,
        dest="output_dir",
        help=f"Output directory for results (default: {DEFAULT_OUTPUT_DIR})",
    )

    # --- Search-strategy knobs ------------------------------------------
    strategy = cli.add_argument_group("Autotuning Strategy")
    strategy.add_argument(
        "--schemes_per_region",
        "-s",
        type=int,
        default=DEFAULT_NUM_SCHEMES,
        dest="num_schemes",
        help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})",
    )
    strategy.add_argument(
        "--pattern_cache",
        type=str,
        default=None,
        dest="pattern_cache_file",
        help="Path to pattern cache YAML for warm-start (optional)",
    )
    strategy.add_argument(
        "--qdq_baseline",
        type=str,
        default=None,
        help="Path to QDQ baseline ONNX model to import quantization patterns (optional)",
    )
    strategy.add_argument(
        "--state_file",
        type=str,
        default=None,
        help="State file path for resume capability (default: <output_dir>/autotuner_state.yaml)",
    )
    strategy.add_argument(
        "--node_filter_list",
        type=str,
        default=None,
        help="Path to a file containing wildcard patterns to filter ONNX nodes (one pattern per line). "
        "Regions without any matching nodes are skipped during autotuning.",
    )

    # --- Quantization options -------------------------------------------
    quant = cli.add_argument_group("Quantization")
    quant.add_argument(
        "--quant_type",
        type=str,
        default=DEFAULT_QUANT_TYPE,
        choices=["int8", "fp8"],
        help=f"Quantization data type (default: {DEFAULT_QUANT_TYPE})",
    )
    quant.add_argument(
        "--default_dq_dtype",
        type=str,
        default=DEFAULT_DQ_DTYPE,
        choices=["float16", "float32", "bfloat16"],
        help="Default DQ output dtype if cannot be deduced (optional)",
    )

    # --- TensorRT benchmarking ------------------------------------------
    trt = cli.add_argument_group("TensorRT Benchmark")
    trt.add_argument(
        "--use_trtexec",
        action="store_true",
        help="Use trtexec for benchmarking (default: False)",
        default=False,
    )
    trt.add_argument(
        "--timing_cache",
        type=str,
        default=DEFAULT_TIMING_CACHE,
        help=f"TensorRT timing cache file (default: {DEFAULT_TIMING_CACHE})",
    )
    trt.add_argument(
        "--warmup_runs",
        type=int,
        default=DEFAULT_WARMUP_RUNS,
        help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})",
    )
    trt.add_argument(
        "--timing_runs",
        type=int,
        default=DEFAULT_TIMING_RUNS,
        help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})",
    )
    trt.add_argument(
        "--plugin_libraries",
        "--plugins",
        type=str,
        nargs="+",
        default=None,
        dest="plugin_libraries",
        help="TensorRT plugin libraries (.so files) to load (optional, space-separated)",
    )
    trt.add_argument(
        "--trtexec_benchmark_args",
        type=str,
        default=None,
        help="Additional command-line arguments to pass to trtexec as a single quoted string. "
        "Example: --trtexec_benchmark_args '--fp16 --workspace=4096 --verbose'",
    )

    # --- Logging ---------------------------------------------------------
    cli.add_argument("--verbose", "-v", action="store_true", help="Enable verbose DEBUG logging")

    return cli
300+
301+
302+
if __name__ == "__main__":
    # Propagate the CLI exit code (0 success, 1 failure, 130 interrupted).
    sys.exit(run_autotune())

modelopt/onnx/quantization/autotune/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,11 @@ def from_dict(
447447
def __str__(self) -> str:
    """Compact one-line summary for debugging/logging."""
    # best_scheme may be None before any scheme has been benchmarked.
    latency = self.best_scheme.latency_ms if self.best_scheme else 0.0
    # Truncate long pattern signatures so log lines stay readable.
    signature = self.pattern_signature
    if len(signature) > 40:
        signature = signature[:40] + "..."
    return (
        f"PatternSchemes(pattern='{signature}', "
        f"schemes={self.num_schemes}, best_latency={latency:.3f}ms)"
    )
454457

@@ -516,19 +519,22 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None:
516519
for scheme in sorted_schemes:
517520
# Check if this scheme is too similar to any already-filtered scheme
518521
too_similar = False
522+
existing_to_remove = None # at most one; remove after inner loop
519523
for existing_scheme in filtered_schemes:
520524
distance = scheme.distance(existing_scheme)
521525
if distance < self.minimum_distance:
522526
# Schemes are too similar, keep the better one
523527
if scheme.latency_ms < existing_scheme.latency_ms:
524-
# New scheme is better, remove existing and add new
525-
filtered_schemes.remove(existing_scheme)
528+
# New scheme is better; mark existing for removal
529+
existing_to_remove = existing_scheme
526530
break
527531
else:
528532
# Existing scheme is better, skip new one
529533
too_similar = True
530534
break
531535

536+
if existing_to_remove is not None:
537+
filtered_schemes.remove(existing_to_remove)
532538
if not too_similar:
533539
filtered_schemes.append(scheme)
534540

0 commit comments

Comments
 (0)