
Commit 15a24d8

ajrasane committed
Apply code quality fixes from CodeRabbit review
- Use the public `quantizer.amax is None` check instead of `hasattr(quantizer, "_amax")` for detecting uncalibrated quantizers.
- Refactor `load_calibration_data` to accept a model instance so calibration transforms match the exact model being quantized (respects --no_pretrained and --model_kwargs).
- Reorder `mtq.disable_quantizer` before `_calibrate_uncalibrated_quantizers` to avoid calibrating quantizers that will be filtered out.
- Wrap the `trtexec` subprocess call in try/except for FileNotFoundError / TimeoutExpired with clearer error messages.
- Add try/finally state restoration to `configure_linear_module_onnx_quantizers` so `_onnx_quantizer_type` is reset after ONNX export.
- Caveat INT4_AWQ in the docstring and argparse description to note it is quantize/export-only and not compatible with --trt_build.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent: 6419b34

2 files changed: 53 additions & 21 deletions


examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 32 additions & 16 deletions
@@ -38,7 +38,9 @@
 """
 Quantize a timm vision model and export to ONNX for TensorRT deployment.

-Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.
+Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes end-to-end
+(quantize + ONNX export + TRT build). INT4_AWQ is quantize/export-only; it is not compatible
+with ``--trt_build``.

 The script will:
 1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
@@ -88,7 +90,11 @@

 # Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
 # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
-_NEEDS_FP8_CONV_OVERRIDE: set[str] = {"NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "MXFP8_DEFAULT_CFG"}
+_NEEDS_FP8_CONV_OVERRIDE: set[str] = {
+    "NVFP4_AWQ_LITE_CFG",
+    "NVFP4_DEFAULT_CFG",
+    "MXFP8_DEFAULT_CFG",
+}
 _NEEDS_INT8_CONV_OVERRIDE: set[str] = {"INT4_AWQ_CFG"}

@@ -125,19 +131,20 @@ def filter_func(name):
     return pattern.match(name) is not None


-def load_calibration_data(model_name, data_size, batch_size, device, with_labels=False):
+def load_calibration_data(model, data_size, batch_size, device, with_labels=False):
     """Load and prepare calibration data.

     Args:
-        model_name: Name of the timm model
+        model: The timm model being quantized; used to derive the calibration transforms so the
+            data pipeline matches the exact model config (respects --no_pretrained and
+            --model_kwargs).
         data_size: Number of samples to load
         batch_size: Batch size for data loader
         device: Device to load data to
         with_labels: If True, return dict with 'image' and 'label' keys (for auto_quantize)
            If False, return just the images (for standard quantize)
     """
     dataset = load_dataset("zh-plus/tiny-imagenet")
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     transforms = timm.data.create_transform(**data_config, is_training=False)

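For context on the `load_calibration_data` change (a sketch, not part of the commit): deriving the data config from the model instance keeps calibration preprocessing in sync with whatever was actually instantiated, instead of rebuilding a pretrained default from the model name. The model name below is illustrative:

import timm

# Build the calibration pipeline from the same instance that will be quantized,
# rather than re-creating the model from its name with pretrained=True.
model = timm.create_model("resnet18", pretrained=False, num_classes=1000)
data_config = timm.data.resolve_model_data_config(model)  # reads the instance's config
transforms = timm.data.create_transform(**data_config, is_training=False)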
@@ -171,11 +178,7 @@ def _calibrate_uncalibrated_quantizers(model, data_loader):
         if not hasattr(module, attr_name):
             continue
         quantizer = getattr(module, attr_name)
-        if (
-            quantizer.is_enabled
-            and not quantizer.block_sizes
-            and not hasattr(quantizer, "_amax")
-        ):
+        if quantizer.is_enabled and not quantizer.block_sizes and quantizer.amax is None:
             quantizer.enable_calib()
             uncalibrated.append(quantizer)

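Why the check changed: `hasattr(quantizer, "_amax")` can be true even when the quantizer holds no calibration result, while the public `amax` property reports `None` until calibration runs. A generic, runnable sketch of the distinction (the `Quantizer` class below is a hypothetical stand-in, not modelopt's `TensorQuantizer`):

class Quantizer:
    """Hypothetical stand-in for a TensorQuantizer-like class (illustration only)."""

    def __init__(self):
        self._amax = None  # the private slot can exist before any calibration

    @property
    def amax(self):
        return self._amax  # public contract: None means "not calibrated yet"

q = Quantizer()
assert hasattr(q, "_amax")  # private probe says "calibrated" -- wrong
assert q.amax is None       # public property correctly says "uncalibrated"
q._amax = 1.0               # calibration would populate this
assert q.amax is not None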
@@ -204,11 +207,14 @@ def forward_loop(model):
     else:
         quantized_model = mtq.quantize(model, config)

-    # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
+    # Disable filtered quantizers BEFORE calibrating override quantizers so we don't
+    # waste time calibrating quantizers that are about to be turned off.
+    mtq.disable_quantizer(quantized_model, filter_func)
+
+    # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize().
     if data_loader is not None:
         _calibrate_uncalibrated_quantizers(quantized_model, data_loader)

-    mtq.disable_quantizer(quantized_model, filter_func)
     return quantized_model

@@ -305,7 +311,10 @@ def get_model_input_shape(model):

 def main():
     parser = argparse.ArgumentParser(
-        description="Quantize timm models to FP8, MXFP8, INT8, NVFP4, INT4_AWQ, or use AUTO quantization"
+        description=(
+            "Quantize timm models to FP8, MXFP8, INT8, NVFP4, or use AUTO quantization. "
+            "INT4_AWQ is supported for quantize/export only and is not compatible with --trt_build."
+        )
     )

     # Model hyperparameters
@@ -424,7 +433,7 @@ def main():
     if args.quantize_mode == "auto":
         # Auto quantization requires labels for loss computation
         data_loader = load_calibration_data(
-            args.timm_model_name,
+            model,
             args.calibration_data_size,
             args.batch_size,
             device,
@@ -446,7 +455,7 @@
     # quantizers require calibration data.
     config = get_quant_config(args.quantize_mode)
     data_loader = load_calibration_data(
-        args.timm_model_name,
+        model,
         args.calibration_data_size,
         args.batch_size,
         device,
@@ -496,7 +505,14 @@ def build_trt_engine(onnx_path):
         "--builderOptimizationLevel=4",
     ]
     print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
-    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+    try:
+        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+    except FileNotFoundError as e:
+        raise RuntimeError(
+            "trtexec not found on PATH; install TensorRT or drop --trt_build."
+        ) from e
+    except subprocess.TimeoutExpired as e:
+        raise RuntimeError(f"trtexec timed out building {onnx_path} after 600s.") from e
     if result.returncode != 0:
         raise RuntimeError(
             f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"

modelopt/torch/quantization/export_onnx.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,25 @@ def configure_linear_module_onnx_quantizers(model):
     input quantizers that would otherwise default to the static path and produce
     TRT_FP4QDQ nodes on activations (which the NVFP4 exporter cannot handle).
     """
+    sentinel = object()
+    originals: list[tuple] = []
     for _, module in model.named_modules():
-        if hasattr(module, "input_quantizer") and module.input_quantizer.block_sizes:
-            module.input_quantizer._onnx_quantizer_type = "dynamic"
-        if hasattr(module, "weight_quantizer") and module.weight_quantizer.block_sizes:
-            module.weight_quantizer._onnx_quantizer_type = "static"
-    yield
+        for attr_name, new_value in (
+            ("input_quantizer", "dynamic"),
+            ("weight_quantizer", "static"),
+        ):
+            quantizer = getattr(module, attr_name, None)
+            if quantizer is None or not quantizer.block_sizes:
+                continue
+            original = getattr(quantizer, "_onnx_quantizer_type", sentinel)
+            originals.append((quantizer, original))
+            quantizer._onnx_quantizer_type = new_value
+    try:
+        yield
+    finally:
+        for quantizer, original in originals:
+            if original is sentinel:
+                if hasattr(quantizer, "_onnx_quantizer_type"):
+                    delattr(quantizer, "_onnx_quantizer_type")
+            else:
+                quantizer._onnx_quantizer_type = original
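The restore logic uses a sentinel to distinguish "attribute did not exist" from "attribute was set to None", so instance state returns to its exact pre-export shape even if the export raises. A generic, runnable sketch of the same save/restore pattern (`override_attr` and `Thing` are hypothetical names, not modelopt APIs):

from contextlib import contextmanager

_SENTINEL = object()  # distinguishes "attribute absent" from "attribute set to None"

@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.<name>, restoring the exact prior state on exit."""
    original = getattr(obj, name, _SENTINEL)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        if original is _SENTINEL:
            if hasattr(obj, name):
                delattr(obj, name)  # attribute didn't exist before: remove it again
        else:
            setattr(obj, name, original)

class Thing:
    pass

t = Thing()
with override_attr(t, "mode", "dynamic"):
    assert t.mode == "dynamic"
assert not hasattr(t, "mode")  # restored: the attribute is gone again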
