|
33 | 33 | DEFAULT_TIMING_CACHE = str(Path(tempfile.gettempdir()) / "trtexec_timing.cache") |
34 | 34 | DEFAULT_WARMUP_RUNS = 5 |
35 | 35 | DEFAULT_TIMING_RUNS = 20 |
# Tuning presets selectable via --mode. Each entry maps to the three knobs that
# apply_mode_presets() copies onto the parsed namespace: the number of schemes
# tried per region, and the warmup/timing run counts used for benchmarking.
# NOTE(review): --mode defaults to "default" and presets are applied for every
# option not given explicitly, so the "default" entry (not the DEFAULT_* argparse
# defaults) appears to be what an unconfigured run actually uses -- confirm intended.
MODE_PRESETS = {
    "quick": dict(schemes_per_region=30, warmup_runs=10, timing_runs=50),
    "default": dict(schemes_per_region=50, warmup_runs=50, timing_runs=100),
    "extensive": dict(schemes_per_region=200, warmup_runs=50, timing_runs=200),
}
| 41 | + |
| 42 | + |
class _StoreWithExplicitFlag(argparse.Action):
    """Argparse action that stores a value and records that it was given explicitly.

    Besides writing the parsed value to ``dest``, the action sets a second
    namespace attribute (named by ``explicit_attr``) to ``True``. Argparse only
    invokes an action when its option appears on the command line, so the marker
    is absent when the option's default is used; this lets mode presets tell
    "user passed the default value" apart from "user passed nothing".
    """

    def __init__(self, explicit_attr: str, *args, **kwargs):
        """Initialise the base action and remember the marker attribute name.

        Args:
            explicit_attr: Namespace attribute set to ``True`` when the option is used.
            *args: Forwarded unchanged to ``argparse.Action``.
            **kwargs: Forwarded unchanged to ``argparse.Action``.
        """
        super().__init__(*args, **kwargs)
        self._explicit_attr = explicit_attr

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` under ``dest`` and raise the explicit marker on ``namespace``."""
        # Value first, marker second -- same write order as two sequential setattr calls.
        assignments = ((self.dest, values), (self._explicit_attr, True))
        for attr_name, attr_value in assignments:
            setattr(namespace, attr_name, attr_value)
| 53 | + |
| 54 | + |
def apply_mode_presets(args) -> None:
    """Copy the ``--mode`` preset values onto the parsed arguments in place.

    The preset named by ``args.mode`` supplies ``num_schemes`` (from its
    ``schemes_per_region`` entry), ``warmup_runs`` and ``timing_runs``. A field is
    left untouched when its ``_explicit_*`` marker on ``args`` is truthy, i.e. the
    option was given on the command line -- even if the given value happens to
    equal the option's default. If ``args.mode`` is not a known preset name,
    nothing is changed.
    """
    chosen = MODE_PRESETS.get(args.mode)
    if chosen is None:
        return

    # (namespace attribute, explicit-marker attribute, key inside the preset)
    bindings = (
        ("num_schemes", "_explicit_num_schemes", "schemes_per_region"),
        ("warmup_runs", "_explicit_warmup_runs", "warmup_runs"),
        ("timing_runs", "_explicit_timing_runs", "timing_runs"),
    )
    for dest, explicit_flag, preset_key in bindings:
        if getattr(args, explicit_flag, False):
            # The user set this option themselves; the preset must not clobber it.
            continue
        setattr(args, dest, chosen[preset_key])
36 | 70 |
|
37 | 71 |
|
38 | 72 | def validate_file_path(path: str | None, description: str) -> Path | None: |
@@ -94,6 +128,7 @@ def run_autotune() -> int: |
94 | 128 | - 130: Interrupted by user (Ctrl+C) |
95 | 129 | """ |
96 | 130 | args = _get_autotune_parser().parse_args() |
| 131 | + apply_mode_presets(args) |
97 | 132 | model_path = validate_file_path(args.onnx_path, "Model file") |
98 | 133 | validate_file_path(args.qdq_baseline, "QDQ baseline model") |
99 | 134 | output_dir = Path(args.output_dir) |
@@ -167,6 +202,12 @@ def _get_autotune_parser() -> argparse.ArgumentParser: |
167 | 202 | # Basic usage |
168 | 203 | python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx |
169 | 204 |
|
| 205 | + # Quick mode (fewer schemes and benchmark runs for fast iteration) |
| 206 | + python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --mode quick |
| 207 | +
|
| 208 | + # Extensive mode (more schemes and runs for thorough tuning) |
| 209 | + python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --mode extensive |
| 210 | +
|
170 | 211 | # Import patterns from QDQ baseline model |
171 | 212 | python -m modelopt.onnx.quantization.autotune \\ |
172 | 213 | --onnx_path model.onnx --qdq_baseline baseline.onnx |
@@ -198,13 +239,26 @@ def _get_autotune_parser() -> argparse.ArgumentParser: |
198 | 239 |
|
199 | 240 | # Autotuning Strategy |
200 | 241 | strategy_group = parser.add_argument_group("Autotuning Strategy") |
| 242 | + strategy_group.add_argument( |
| 243 | + "--mode", |
| 244 | + type=str, |
| 245 | + default="default", |
| 246 | + choices=["quick", "default", "extensive"], |
| 247 | + help="Preset for schemes_per_region, warmup_runs, and timing_runs. " |
| 248 | + "'quick': fewer schemes/runs for fast iteration; " |
| 249 | + "'default': balanced; " |
| 250 | + "'extensive': more schemes/runs for thorough tuning. " |
| 251 | + "Explicit --schemes_per_region, --warmup_runs, --timing_runs override the preset.", |
| 252 | + ) |
201 | 253 | strategy_group.add_argument( |
202 | 254 | "--schemes_per_region", |
203 | 255 | "-s", |
204 | 256 | type=int, |
205 | 257 | default=DEFAULT_NUM_SCHEMES, |
206 | 258 | dest="num_schemes", |
207 | | - help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})", |
| 259 | + action=_StoreWithExplicitFlag, |
| 260 | + explicit_attr="_explicit_num_schemes", |
| 261 | + help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES}; overridden by --mode)", |
208 | 262 | ) |
209 | 263 | strategy_group.add_argument( |
210 | 264 | "--pattern_cache", |
@@ -268,13 +322,17 @@ def _get_autotune_parser() -> argparse.ArgumentParser: |
268 | 322 | "--warmup_runs", |
269 | 323 | type=int, |
270 | 324 | default=DEFAULT_WARMUP_RUNS, |
271 | | - help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})", |
| 325 | + action=_StoreWithExplicitFlag, |
| 326 | + explicit_attr="_explicit_warmup_runs", |
| 327 | + help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS}; overridden by --mode)", |
272 | 328 | ) |
273 | 329 | trt_group.add_argument( |
274 | 330 | "--timing_runs", |
275 | 331 | type=int, |
276 | 332 | default=DEFAULT_TIMING_RUNS, |
277 | | - help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})", |
| 333 | + action=_StoreWithExplicitFlag, |
| 334 | + explicit_attr="_explicit_timing_runs", |
| 335 | + help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS}; overridden by --mode)", |
278 | 336 | ) |
279 | 337 | trt_group.add_argument( |
280 | 338 | "--plugin_libraries", |
|
0 commit comments