-
Notifications
You must be signed in to change notification settings - Fork 353
Add ResNet50 support for torch_onnx quantization workflow #1263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
56a1d21
9810e1d
54cefb2
ba020d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import copy | ||
| import json | ||
| import re | ||
| import subprocess | ||
| import sys | ||
| import warnings | ||
| from pathlib import Path | ||
|
|
@@ -35,13 +36,17 @@ | |
| import modelopt.torch.quantization as mtq | ||
|
|
||
| """ | ||
| This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4, | ||
| or using auto quantization for optimal per-layer quantization. | ||
| Quantize a timm vision model and export to ONNX for TensorRT deployment. | ||
|
|
||
| Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes. | ||
|
|
||
| The script will: | ||
| 1. Given the model name, create a timm torch model. | ||
| 2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode. | ||
| 3. Export the quantized torch model to ONNX format. | ||
| 1. Load a pretrained timm model (e.g., ViT, Swin, ResNet). | ||
| 2. Quantize the model using the specified mode. For models with Conv2d layers, | ||
| Conv2d quantization is automatically overridden for TensorRT compatibility | ||
| (FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ). | ||
| 3. Export the quantized model to ONNX with FP16 weights. | ||
| 4. Optionally evaluate accuracy on ImageNet-1k before and after quantization. | ||
| """ | ||
|
|
||
|
|
||
|
|
@@ -109,7 +114,8 @@ def filter_func(name): | |
| """Filter function to exclude certain layers from quantization.""" | ||
| pattern = re.compile( | ||
| r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" | ||
| r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" | ||
| r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" | ||
| r"downsample|maxpool|global_pool).*" | ||
| ) | ||
| return pattern.match(name) is not None | ||
|
|
||
|
|
@@ -147,6 +153,40 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels | |
| ) | ||
|
|
||
|
|
||
| def _calibrate_uncalibrated_quantizers(model, data_loader): | ||
| """Calibrate FP8 quantizers that weren't calibrated by mtq.quantize(). | ||
|
|
||
| When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not | ||
| be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard | ||
| calibration. This function explicitly calibrates those uncalibrated quantizers. | ||
| """ | ||
| uncalibrated = [] | ||
| for _, module in model.named_modules(): | ||
| for attr_name in ("input_quantizer", "weight_quantizer"): | ||
| if not hasattr(module, attr_name): | ||
| continue | ||
| quantizer = getattr(module, attr_name) | ||
| if ( | ||
| quantizer.is_enabled | ||
| and not quantizer.block_sizes | ||
| and not hasattr(quantizer, "_amax") | ||
| ): | ||
| quantizer.enable_calib() | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| uncalibrated.append(quantizer) | ||
|
|
||
| if not uncalibrated: | ||
| return | ||
|
|
||
| model.eval() | ||
| with torch.no_grad(): | ||
| for batch in data_loader: | ||
| model(batch) | ||
|
|
||
| for quantizer in uncalibrated: | ||
| quantizer.disable_calib() | ||
| quantizer.load_calib_amax() | ||
|
|
||
|
|
||
| def quantize_model(model, config, data_loader=None): | ||
| """Quantize the model using the given config and calibration data.""" | ||
| if data_loader is not None: | ||
|
|
@@ -159,6 +199,10 @@ def forward_loop(model): | |
| else: | ||
| quantized_model = mtq.quantize(model, config) | ||
|
|
||
| # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize() | ||
| if data_loader is not None: | ||
| _calibrate_uncalibrated_quantizers(quantized_model, data_loader) | ||
|
|
||
| mtq.disable_quantizer(quantized_model, filter_func) | ||
| return quantized_model | ||
|
|
||
|
|
@@ -185,6 +229,38 @@ def _disable_inplace_relu(model): | |
| module.inplace = False | ||
|
|
||
|
|
||
| def _override_conv2d_to_fp8(model, data_loader): | ||
| """Override Conv2d layers with NVFP4/MXFP8 block quantization to FP8. | ||
|
|
||
| TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. | ||
| This overrides Conv2d block quantizers to FP8 per-tensor and calibrates them. | ||
| """ | ||
| overridden = [] | ||
| for _, module in model.named_modules(): | ||
| if not isinstance(module, torch.nn.Conv2d): | ||
| continue | ||
| for attr_name in ("input_quantizer", "weight_quantizer"): | ||
| if not hasattr(module, attr_name): | ||
| continue | ||
| quantizer = getattr(module, attr_name) | ||
| if quantizer.is_enabled and quantizer.block_sizes: | ||
| # Override to FP8 per-tensor | ||
| quantizer.block_sizes = None | ||
| quantizer._num_bits = (4, 3) | ||
| quantizer._axis = None | ||
|
Comment on lines
+248
to
+250
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🏁 Script executed: #!/bin/bash
# Check if there's a public API for setting quantizer num_bits/axis
rg -n "def.*num_bits|def.*axis" --type=py modelopt/torch/quantization/Repository: NVIDIA/Model-Optimizer Length of output: 3259 🏁 Script executed: #!/bin/bash
# Get the file and examine the problematic lines with context
cat -n examples/torch_onnx/torch_quant_to_onnx.py | sed -n '235,260p'Repository: NVIDIA/Model-Optimizer Length of output: 1289 🏁 Script executed: #!/bin/bash
# Examine the public num_bits and axis property implementations
cat -n modelopt/torch/quantization/nn/modules/tensor_quantizer.py | sed -n '272,285p'
cat -n modelopt/torch/quantization/nn/modules/tensor_quantizer.py | sed -n '372,385p'
Repository: NVIDIA/Model-Optimizer Length of output: 967
Use public property setters for `num_bits` and `axis`. At lines 249-250, directly assigning to the private attributes bypasses the public API. Change:
quantizer._num_bits = (4, 3)
quantizer._axis = None
To:
quantizer.num_bits = (4, 3)
quantizer.axis = None
Line 248 correctly uses the public `block_sizes` setter. 🤖 Prompt for AI Agents |
||
| quantizer.enable_calib() | ||
| overridden.append(quantizer) | ||
|
|
||
| if overridden: | ||
| model.eval() | ||
| with torch.no_grad(): | ||
| for batch in data_loader: | ||
| model(batch["image"]) | ||
| for quantizer in overridden: | ||
| quantizer.disable_calib() | ||
| quantizer.load_calib_amax() | ||
|
Comment on lines
+232
to
+261
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apply the Conv2d TRT restriction before the auto-quant search.
Also applies to: 307-309 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def auto_quantize_model( | ||
| model, | ||
| data_loader, | ||
|
|
@@ -233,6 +309,10 @@ def auto_quantize_model( | |
| verbose=True, | ||
| ) | ||
|
|
||
| # Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility. | ||
| # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. | ||
| _override_conv2d_to_fp8(quantized_model, data_loader) | ||
|
|
||
| # Disable quantization for specified layers | ||
| mtq.disable_quantizer(quantized_model, filter_func) | ||
|
|
||
|
|
@@ -320,6 +400,11 @@ def main(): | |
| default=128, | ||
| help="Number of scoring steps for auto quantization. Default is 128.", | ||
| ) | ||
| parser.add_argument( | ||
| "--trt_build", | ||
| action="store_true", | ||
| help="Build a TensorRT engine from the exported ONNX model using trtexec.", | ||
| ) | ||
| parser.add_argument( | ||
| "--no_pretrained", | ||
| action="store_true", | ||
|
|
@@ -378,18 +463,18 @@ def main(): | |
| args.num_score_steps, | ||
| ) | ||
| else: | ||
| # Standard quantization - only load calibration data if needed | ||
| # Standard quantization - load calibration data | ||
| # Note: MXFP8 is dynamic and does not need calibration itself, but when | ||
| # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8 | ||
| # quantizers require calibration data. | ||
| config = get_quant_config(args.quantize_mode) | ||
| if args.quantize_mode == "mxfp8": | ||
| data_loader = None | ||
| else: | ||
| data_loader = load_calibration_data( | ||
| args.timm_model_name, | ||
| args.calibration_data_size, | ||
| args.batch_size, | ||
| device, | ||
| with_labels=False, | ||
| ) | ||
| data_loader = load_calibration_data( | ||
| args.timm_model_name, | ||
| args.calibration_data_size, | ||
| args.batch_size, | ||
| device, | ||
| with_labels=False, | ||
| ) | ||
|
|
||
| quantized_model = quantize_model(model, config, data_loader) | ||
|
|
||
|
|
@@ -421,6 +506,26 @@ def main(): | |
|
|
||
| print(f"Quantized ONNX model is saved to {args.onnx_save_path}") | ||
|
|
||
| if args.trt_build: | ||
| build_trt_engine(args.onnx_save_path) | ||
|
|
||
|
|
||
| def build_trt_engine(onnx_path): | ||
| """Build a TensorRT engine from the exported ONNX model using trtexec.""" | ||
| cmd = [ | ||
| "trtexec", | ||
| f"--onnx={onnx_path}", | ||
| "--stronglyTyped", | ||
| "--builderOptimizationLevel=4", | ||
| ] | ||
| print(f"\nBuilding TensorRT engine: {' '.join(cmd)}") | ||
| result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) | ||
| if result.returncode != 0: | ||
| raise RuntimeError( | ||
| f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}" | ||
| ) | ||
| print("TensorRT engine build succeeded.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don’t advertise INT4_AWQ as supported end-to-end here.
The PR objectives still call out INT4_AWQ as a known limitation, but this docstring now groups it with the working modes. Please caveat or remove it here so users do not assume this example is expected to succeed in that mode. ✏️ Suggested wording
🤖 Prompt for AI Agents