Skip to content

Commit 0171a0c

Browse files
committed
Integrate Automated QDQ placement tool - part 4.4
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 2242304 commit 0171a0c

6 files changed

Lines changed: 352 additions & 9 deletions

File tree

modelopt/onnx/quantization/autotune/__main__.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,40 @@
3333
DEFAULT_TIMING_CACHE = str(Path(tempfile.gettempdir()) / "trtexec_timing.cache")
3434
DEFAULT_WARMUP_RUNS = 5
3535
DEFAULT_TIMING_RUNS = 20
36+
MODE_PRESETS = {
37+
"quick": {"schemes_per_region": 30, "warmup_runs": 10, "timing_runs": 50},
38+
"default": {"schemes_per_region": 50, "warmup_runs": 50, "timing_runs": 100},
39+
"extensive": {"schemes_per_region": 200, "warmup_runs": 50, "timing_runs": 200},
40+
}
41+
42+
43+
class _StoreWithExplicitFlag(argparse.Action):
44+
"""Store the value and set an 'explicit' flag on the namespace so mode presets do not override."""
45+
46+
def __init__(self, explicit_attr: str, *args, **kwargs):
47+
self._explicit_attr = explicit_attr
48+
super().__init__(*args, **kwargs)
49+
50+
def __call__(self, parser, namespace, values, option_string=None):
51+
setattr(namespace, self.dest, values)
52+
setattr(namespace, self._explicit_attr, True)
53+
54+
55+
def apply_mode_presets(args) -> None:
56+
"""Apply --mode preset to schemes_per_region, warmup_runs, timing_runs.
57+
58+
Only applies preset for an option when that option was not explicitly set on the
59+
command line (explicit flags override the preset even when the value equals the default).
60+
"""
61+
if args.mode not in MODE_PRESETS:
62+
return
63+
preset = MODE_PRESETS[args.mode]
64+
if not getattr(args, "_explicit_num_schemes", False):
65+
args.num_schemes = preset["schemes_per_region"]
66+
if not getattr(args, "_explicit_warmup_runs", False):
67+
args.warmup_runs = preset["warmup_runs"]
68+
if not getattr(args, "_explicit_timing_runs", False):
69+
args.timing_runs = preset["timing_runs"]
3670

3771

3872
def validate_file_path(path: str | None, description: str) -> Path | None:
@@ -94,6 +128,7 @@ def run_autotune() -> int:
94128
- 130: Interrupted by user (Ctrl+C)
95129
"""
96130
args = _get_autotune_parser().parse_args()
131+
apply_mode_presets(args)
97132
model_path = validate_file_path(args.onnx_path, "Model file")
98133
validate_file_path(args.qdq_baseline, "QDQ baseline model")
99134
output_dir = Path(args.output_dir)
@@ -167,6 +202,12 @@ def _get_autotune_parser() -> argparse.ArgumentParser:
167202
# Basic usage
168203
python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx
169204
205+
# Quick mode (fewer schemes and benchmark runs for fast iteration)
206+
python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --mode quick
207+
208+
# Extensive mode (more schemes and runs for thorough tuning)
209+
python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --mode extensive
210+
170211
# Import patterns from QDQ baseline model
171212
python -m modelopt.onnx.quantization.autotune \\
172213
--onnx_path model.onnx --qdq_baseline baseline.onnx
@@ -198,13 +239,26 @@ def _get_autotune_parser() -> argparse.ArgumentParser:
198239

199240
# Autotuning Strategy
200241
strategy_group = parser.add_argument_group("Autotuning Strategy")
242+
strategy_group.add_argument(
243+
"--mode",
244+
type=str,
245+
default="default",
246+
choices=["quick", "default", "extensive"],
247+
help="Preset for schemes_per_region, warmup_runs, and timing_runs. "
248+
"'quick': fewer schemes/runs for fast iteration; "
249+
"'default': balanced; "
250+
"'extensive': more schemes/runs for thorough tuning. "
251+
"Explicit --schemes_per_region, --warmup_runs, --timing_runs override the preset.",
252+
)
201253
strategy_group.add_argument(
202254
"--schemes_per_region",
203255
"-s",
204256
type=int,
205257
default=DEFAULT_NUM_SCHEMES,
206258
dest="num_schemes",
207-
help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})",
259+
action=_StoreWithExplicitFlag,
260+
explicit_attr="_explicit_num_schemes",
261+
help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES}; overridden by --mode)",
208262
)
209263
strategy_group.add_argument(
210264
"--pattern_cache",
@@ -268,13 +322,17 @@ def _get_autotune_parser() -> argparse.ArgumentParser:
268322
"--warmup_runs",
269323
type=int,
270324
default=DEFAULT_WARMUP_RUNS,
271-
help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})",
325+
action=_StoreWithExplicitFlag,
326+
explicit_attr="_explicit_warmup_runs",
327+
help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS}; overridden by --mode)",
272328
)
273329
trt_group.add_argument(
274330
"--timing_runs",
275331
type=int,
276332
default=DEFAULT_TIMING_RUNS,
277-
help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})",
333+
action=_StoreWithExplicitFlag,
334+
explicit_attr="_explicit_timing_runs",
335+
help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS}; overridden by --mode)",
278336
)
279337
trt_group.add_argument(
280338
"--plugin_libraries",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
[project.optional-dependencies]
4949
onnx = [
5050
"cppimport",
51+
"cuda-python",
5152
"cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
5253
"lief",
5354
"ml_dtypes",

tests/_test_utils/onnx/quantization/autotune/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
def _create_simple_conv_onnx_model():
2727
"""Build ONNX model: Input -> Conv -> Relu -> Output (minimal for autotuner tests)."""
28-
input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [32, 3, 224, 224])
28+
input_tensor = helper.make_tensor_value_info(
29+
"input", onnx.TensorProto.FLOAT, [64, 32, 224, 224]
30+
)
2931
output_tensor = helper.make_tensor_value_info(
30-
"output", onnx.TensorProto.FLOAT, [32, 64, 224, 224]
32+
"output", onnx.TensorProto.FLOAT, [64, 64, 224, 224]
3133
)
3234
conv_node = helper.make_node(
3335
"Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv"
@@ -40,7 +42,7 @@ def _create_simple_conv_onnx_model():
4042
[output_tensor],
4143
initializer=[
4244
helper.make_tensor(
43-
"conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3)
45+
"conv_weight", onnx.TensorProto.FLOAT, [64, 32, 3, 3], [0.1] * (64 * 32 * 3 * 3)
4446
)
4547
],
4648
)

tests/unit/onnx/quantization/autotune/test_autotune_config.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@
1414
# limitations under the License.
1515

1616
"""
17-
Tests for the Config class in the autotuner.
17+
Tests for the Config class and CLI mode presets in the autotuner.
1818
19-
Tests configuration parameter validation and defaults.
19+
Tests configuration parameter validation, defaults, and CLI --mode preset
20+
selection and explicit-flag precedence.
2021
"""
2122

23+
from modelopt.onnx.quantization.autotune.__main__ import (
24+
MODE_PRESETS,
25+
_get_autotune_parser,
26+
apply_mode_presets,
27+
)
2228
from modelopt.onnx.quantization.autotune.common import Config
2329

2430

@@ -95,3 +101,96 @@ def test_pattern_cache_params(self):
95101

96102
assert config.pattern_cache_minimum_distance == 3
97103
assert config.pattern_cache_max_entries_per_pattern == 10
104+
105+
106+
class TestModePresets:
107+
"""Test --mode preset selection and explicit-flag precedence."""
108+
109+
@staticmethod
110+
def _parse_cli(argv):
111+
"""Parse argv with the autotune CLI parser and apply mode presets."""
112+
parser = _get_autotune_parser()
113+
args = parser.parse_args(argv)
114+
apply_mode_presets(args)
115+
return args
116+
117+
def test_mode_quick_applies_preset_when_no_explicit_flags(self):
118+
"""With --mode quick and no explicit schemes/warmup/timing, preset values are used."""
119+
args = self._parse_cli(["--onnx_path", "model.onnx", "--mode", "quick"])
120+
preset = MODE_PRESETS["quick"]
121+
assert args.num_schemes == preset["schemes_per_region"]
122+
assert args.warmup_runs == preset["warmup_runs"]
123+
assert args.timing_runs == preset["timing_runs"]
124+
125+
def test_mode_default_applies_preset_when_no_explicit_flags(self):
126+
"""With --mode default and no explicit flags, preset values are used."""
127+
args = self._parse_cli(["--onnx_path", "model.onnx", "--mode", "default"])
128+
preset = MODE_PRESETS["default"]
129+
assert args.num_schemes == preset["schemes_per_region"]
130+
assert args.warmup_runs == preset["warmup_runs"]
131+
assert args.timing_runs == preset["timing_runs"]
132+
133+
def test_mode_extensive_applies_preset_when_no_explicit_flags(self):
134+
"""With --mode extensive and no explicit flags, preset values are used."""
135+
args = self._parse_cli(["--onnx_path", "model.onnx", "--mode", "extensive"])
136+
preset = MODE_PRESETS["extensive"]
137+
assert args.num_schemes == preset["schemes_per_region"]
138+
assert args.warmup_runs == preset["warmup_runs"]
139+
assert args.timing_runs == preset["timing_runs"]
140+
141+
def test_explicit_schemes_per_region_overrides_mode_preset(self):
142+
"""Explicit --schemes_per_region is kept even when it differs from preset."""
143+
args = self._parse_cli(
144+
["--onnx_path", "model.onnx", "--mode", "default", "--schemes_per_region", "99"]
145+
)
146+
assert args.num_schemes == 99
147+
assert args.warmup_runs == MODE_PRESETS["default"]["warmup_runs"]
148+
assert args.timing_runs == MODE_PRESETS["default"]["timing_runs"]
149+
150+
def test_explicit_default_value_not_overridden_by_mode(self):
151+
"""Explicit --schemes_per_region 30 (parser default) is not overridden by --mode default."""
152+
args = self._parse_cli(
153+
["--onnx_path", "model.onnx", "--mode", "default", "--schemes_per_region", "30"]
154+
)
155+
assert args.num_schemes == 30
156+
157+
def test_explicit_warmup_runs_overrides_mode_preset(self):
158+
"""Explicit --warmup_runs is kept and not overridden by preset."""
159+
args = self._parse_cli(
160+
["--onnx_path", "model.onnx", "--mode", "extensive", "--warmup_runs", "3"]
161+
)
162+
assert args.warmup_runs == 3
163+
assert args.num_schemes == MODE_PRESETS["extensive"]["schemes_per_region"]
164+
assert args.timing_runs == MODE_PRESETS["extensive"]["timing_runs"]
165+
166+
def test_explicit_timing_runs_overrides_mode_preset(self):
167+
"""Explicit --timing_runs is kept and not overridden by preset."""
168+
args = self._parse_cli(
169+
["--onnx_path", "model.onnx", "--mode", "quick", "--timing_runs", "7"]
170+
)
171+
assert args.timing_runs == 7
172+
assert args.num_schemes == MODE_PRESETS["quick"]["schemes_per_region"]
173+
assert args.warmup_runs == MODE_PRESETS["quick"]["warmup_runs"]
174+
175+
def test_multiple_explicit_overrides_mode_preset(self):
176+
"""Multiple explicit flags override only their respective preset values."""
177+
args = self._parse_cli(
178+
[
179+
"--onnx_path",
180+
"model.onnx",
181+
"--mode",
182+
"extensive",
183+
"--schemes_per_region",
184+
"10",
185+
"--timing_runs",
186+
"5",
187+
]
188+
)
189+
assert args.num_schemes == 10
190+
assert args.timing_runs == 5
191+
assert args.warmup_runs == MODE_PRESETS["extensive"]["warmup_runs"]
192+
193+
def test_short_flag_schemes_per_region_overrides_mode(self):
194+
"""Short form -s for schemes_per_region is treated as explicit and overrides preset."""
195+
args = self._parse_cli(["--onnx_path", "model.onnx", "--mode", "default", "-s", "25"])
196+
assert args.num_schemes == 25

0 commit comments

Comments
 (0)