Skip to content

Commit 56a1d21

Browse files
ajrasane and claude
committed
Add ResNet50 support for torch_onnx quantization workflow
Add end-to-end support for ResNet50 (Conv2d-heavy model) in the torch_onnx quantization → ONNX export → TRT engine pipeline.

Key fixes for Conv2d-heavy models:
- Disable FP8 Conv2d weight quantizers during ONNX export to avoid the TorchScript exporter's "kernel of unknown shape" error (FP8 DequantizeLinear produces dynamic-shape outputs incompatible with Conv2d's static kernel requirement)
- Disable autocast for FP8/INT8 quantized models during export (prevents dynamic-shape kernels from autocast-induced FP16 casting)
- Fix configure_linear_module_onnx_quantizers to handle all modules with block quantization (not just nn.Linear), fixing NVFP4/MXFP8 export for models with quantized non-Linear modules like MaxPool2d
- Add a calibration step for FP8 override quantizers that aren't calibrated by mtq.quantize() in MXFP8/NVFP4 modes
- Override Conv2d block quantizers to FP8 in auto mode for TRT compatibility
- Add maxpool and global_pool to filter_func (TRT DynamicQuantize requires 2D/3D input, but pooling layers operate on 4D tensors)
- Always load calibration data (MXFP8 Conv2d FP8 overrides need it)

Signed-off-by: ajrasane <arasane@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent b6c6ec3 commit 56a1d21

File tree

6 files changed

+142
-17
lines changed

6 files changed

+142
-17
lines changed

examples/torch_onnx/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
307307
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
308308
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
309309
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||||
310+
| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) ||||| ||
310311

311312
## Resources
312313

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def filter_func(name):
109109
"""Filter function to exclude certain layers from quantization."""
110110
pattern = re.compile(
111111
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
112-
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
112+
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
113+
r"downsample|maxpool|global_pool).*"
113114
)
114115
return pattern.match(name) is not None
115116

@@ -147,6 +148,36 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
147148
)
148149

149150

151+
def _calibrate_uncalibrated_quantizers(model, data_loader):
152+
"""Calibrate FP8 quantizers that weren't calibrated by mtq.quantize().
153+
154+
When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not
155+
be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard
156+
calibration. This function explicitly calibrates those uncalibrated quantizers.
157+
"""
158+
uncalibrated = []
159+
for _, module in model.named_modules():
160+
for attr_name in ("input_quantizer", "weight_quantizer"):
161+
if not hasattr(module, attr_name):
162+
continue
163+
quantizer = getattr(module, attr_name)
164+
if quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer, "_amax"):
165+
quantizer.enable_calib()
166+
uncalibrated.append(quantizer)
167+
168+
if not uncalibrated:
169+
return
170+
171+
model.eval()
172+
with torch.no_grad():
173+
for batch in data_loader:
174+
model(batch)
175+
176+
for quantizer in uncalibrated:
177+
quantizer.disable_calib()
178+
quantizer.load_calib_amax()
179+
180+
150181
def quantize_model(model, config, data_loader=None):
151182
"""Quantize the model using the given config and calibration data."""
152183
if data_loader is not None:
@@ -159,6 +190,10 @@ def forward_loop(model):
159190
else:
160191
quantized_model = mtq.quantize(model, config)
161192

193+
# Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
194+
if data_loader is not None:
195+
_calibrate_uncalibrated_quantizers(quantized_model, data_loader)
196+
162197
mtq.disable_quantizer(quantized_model, filter_func)
163198
return quantized_model
164199

@@ -185,6 +220,38 @@ def _disable_inplace_relu(model):
185220
module.inplace = False
186221

187222

223+
def _override_conv2d_to_fp8(model, data_loader):
224+
"""Override Conv2d layers with NVFP4/MXFP8 block quantization to FP8.
225+
226+
TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
227+
This overrides Conv2d block quantizers to FP8 per-tensor and calibrates them.
228+
"""
229+
overridden = []
230+
for _, module in model.named_modules():
231+
if not isinstance(module, torch.nn.Conv2d):
232+
continue
233+
for attr_name in ("input_quantizer", "weight_quantizer"):
234+
if not hasattr(module, attr_name):
235+
continue
236+
quantizer = getattr(module, attr_name)
237+
if quantizer.is_enabled and quantizer.block_sizes:
238+
# Override to FP8 per-tensor
239+
quantizer.block_sizes = None
240+
quantizer._num_bits = (4, 3)
241+
quantizer._axis = None
242+
quantizer.enable_calib()
243+
overridden.append(quantizer)
244+
245+
if overridden:
246+
model.eval()
247+
with torch.no_grad():
248+
for batch in data_loader:
249+
model(batch["image"])
250+
for quantizer in overridden:
251+
quantizer.disable_calib()
252+
quantizer.load_calib_amax()
253+
254+
188255
def auto_quantize_model(
189256
model,
190257
data_loader,
@@ -233,6 +300,10 @@ def auto_quantize_model(
233300
verbose=True,
234301
)
235302

303+
# Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility.
304+
# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
305+
_override_conv2d_to_fp8(quantized_model, data_loader)
306+
236307
# Disable quantization for specified layers
237308
mtq.disable_quantizer(quantized_model, filter_func)
238309

@@ -378,18 +449,18 @@ def main():
378449
args.num_score_steps,
379450
)
380451
else:
381-
# Standard quantization - only load calibration data if needed
452+
# Standard quantization - load calibration data
453+
# Note: MXFP8 is dynamic and does not need calibration itself, but when
454+
# Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
455+
# quantizers require calibration data.
382456
config = get_quant_config(args.quantize_mode)
383-
if args.quantize_mode == "mxfp8":
384-
data_loader = None
385-
else:
386-
data_loader = load_calibration_data(
387-
args.timm_model_name,
388-
args.calibration_data_size,
389-
args.batch_size,
390-
device,
391-
with_labels=False,
392-
)
457+
data_loader = load_calibration_data(
458+
args.timm_model_name,
459+
args.calibration_data_size,
460+
args.batch_size,
461+
device,
462+
with_labels=False,
463+
)
393464

394465
quantized_model = quantize_model(model, config, data_loader)
395466

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Utility functions related to Onnx."""
1717

1818
import base64
19+
import contextlib
1920
import inspect
2021
import json
2122
import logging
@@ -402,6 +403,29 @@ def is_fp8_quantized(model: nn.Module) -> bool:
402403
return False
403404

404405

406+
@contextlib.contextmanager
407+
def _disable_fp8_conv_weight_quantizers(model: nn.Module):
408+
"""Temporarily disable FP8 weight quantizers on Conv layers during ONNX export.
409+
410+
The TorchScript ONNX exporter requires static kernel shapes for Conv operations,
411+
but FP8 weight quantization (TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear)
412+
produces dynamic-shape outputs that break this requirement. Disabling Conv weight
413+
quantizers during export allows the Conv to export with static-shape FP16/FP32
414+
weights. Conv activations still have FP8 QDQ nodes (input quantizers remain enabled).
415+
"""
416+
disabled = []
417+
for _, module in model.named_modules():
418+
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
419+
if hasattr(module, "weight_quantizer") and module.weight_quantizer.is_enabled:
420+
module.weight_quantizer.disable()
421+
disabled.append(module)
422+
try:
423+
yield
424+
finally:
425+
for module in disabled:
426+
module.weight_quantizer.enable()
427+
428+
405429
def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
406430
"""Real quantizes the weights in the onnx model.
407431
@@ -522,7 +546,11 @@ def get_onnx_bytes_and_metadata(
522546
input_none_names = list(set(tree_spec_input.names) - set(input_names))
523547

524548
use_torch_autocast = not (
525-
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
549+
is_fp4_quantized(model)
550+
or is_mxfp8_quantized(model)
551+
or is_fp8_quantized(model)
552+
or is_int8_quantized(model)
553+
or weights_dtype == "fp32"
526554
)
527555
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
528556

@@ -556,7 +584,14 @@ def get_onnx_bytes_and_metadata(
556584
if is_fp4_quantized(model) or is_mxfp8_quantized(model)
557585
else nullcontext()
558586
)
559-
with torch.inference_mode(), autocast, quantizer_context:
587+
# Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static
588+
# kernel shapes, but FP8 DequantizeLinear produces dynamic shapes.
589+
conv_wq_context = (
590+
_disable_fp8_conv_weight_quantizers(model)
591+
if is_fp8_quantized(model)
592+
else nullcontext()
593+
)
594+
with torch.inference_mode(), autocast, quantizer_context, conv_wq_context:
560595
additional_kwargs = {}
561596
if not dynamo_export:
562597
additional_kwargs["dynamic_axes"] = dynamic_axes
@@ -598,7 +633,12 @@ def get_onnx_bytes_and_metadata(
598633
onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
599634

600635
if weights_dtype in ["fp16", "bf16"]:
601-
if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model):
636+
if (
637+
is_int4_quantized(model)
638+
or is_mxfp8_quantized(model)
639+
or is_fp8_quantized(model)
640+
or is_int8_quantized(model)
641+
):
602642
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
603643
onnx_opt_graph = convert_float_to_float16(
604644
onnx_opt_graph,

modelopt/torch/quantization/export_onnx.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,20 @@ def export_fp4(
656656

657657
@contextlib.contextmanager
658658
def configure_linear_module_onnx_quantizers(model):
659-
"""Sets the onnx export attributes for the given model."""
659+
"""Sets the onnx export attributes for the given model.
660+
661+
For modules with block quantization (NVFP4/MXFP8):
662+
- Weight quantizers use "static" export (TRT_FP4QDQ for NVFP4, DQ-only for MXFP8)
663+
- Input/activation quantizers use "dynamic" export (TRT_FP4DynamicQuantize, etc.)
664+
665+
This must be set for ALL modules with block quantization, not just nn.Linear,
666+
because models like ResNet have non-Linear modules (e.g., MaxPool2d) with NVFP4/MXFP8
667+
input quantizers that would otherwise default to the static path and produce
668+
TRT_FP4QDQ nodes on activations (which the NVFP4 exporter cannot handle).
669+
"""
660670
for _, module in model.named_modules():
661-
if isinstance(module, torch.nn.Linear):
671+
if hasattr(module, "input_quantizer") and module.input_quantizer.block_sizes:
662672
module.input_quantizer._onnx_quantizer_type = "dynamic"
673+
if hasattr(module, "weight_quantizer") and module.weight_quantizer.block_sizes:
663674
module.weight_quantizer._onnx_quantizer_type = "static"
664675
yield

tests/_test_utils/torch/vision_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def get_model_and_input(on_gpu: bool = False):
117117
# "dm_nfnet_f0",
118118
"efficientnet_b0",
119119
"swin_tiny_patch4_window7_224",
120+
"resnet50",
120121
],
121122
_create_timm_fn,
122123
),

tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
2929
"swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
3030
"swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
31+
"resnet50": ("resnet50", None),
3132
}
3233

3334
# Builder optimization level: 4 for low-bit modes, 3 otherwise

0 commit comments

Comments (0)