Skip to content

Commit 202c3d3

Browse files
ajrasane and claude authored
Add SwinTransformer support for torch_onnx quantization workflow (#1235)
## Summary - Enable end-to-end quantize → ONNX export → TRT engine pipeline for SwinTransformer models (v1 and v2) across FP8, INT8, MXFP8, NVFP4, and auto precision modes - Add Conv2d quantization overrides for TRT compatibility (TRT only supports FP8/INT8 for convolutions) - Fix FP8 LayerNorm type mismatch in TRT stronglyTyped mode by adding `LayerNormalization` to `change_casts_to_fp16` - Fix `cast_initializer_to_dtype` crash when node has no initializer inputs - Simplify `download_example_onnx.py` to a single `--timm_model_name` (required) flag, removing redundant `--vit` and `--llama` flags - Add vision model support matrix to README (ViT, Swin, SwinV2) - Rewrite tests: parametrize over (ViT, Swin, SwinV2) × (fp8, int8, mxfp8, nvfp4, auto) with TRT engine build verification ## Test plan - [ ] `python -m pytest tests/examples/torch_onnx/test_torch_quant_to_onnx.py -v` — 15 tests (3 models × 5 modes), all pass - [ ] Verified Swin accuracy on ImageNet-1k across all precisions (FP8: 81.29%, INT8: 81.12%, MXFP8: 81.32%, NVFP4: 80.79%, Auto: 80.84% TRT top-1 vs 81.37% base) - [ ] INT4_AWQ deferred (TODO in test file) — requires INT4 exporter changes for non-MatMul/Gemm consumer patterns 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * ONNX export supports arbitrary timm vision models with auto device selection and new CLI options (--timm_model_name, --model_kwargs, --no_pretrained); batch-size/input sizing is now model-generic. * **Bug Fixes** * Expanded FP16/BF16 cast handling to additional ONNX ops. * Disabled inplace ReLU before auto-quantization to avoid incorrect transforms. * Conv2d quantization overrides added for improved TensorRT compatibility. * Safer handling when initializers are missing during dtype casting. * **Documentation** * README updated with supported models table, quantization mappings, and example CLI usage. 
* **Tests** * Tests expanded to multiple architectures/quant modes and now verify TensorRT engine build. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6403389 commit 202c3d3

File tree

8 files changed

+189
-91
lines changed

8 files changed

+189
-91
lines changed

examples/onnx_ptq/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Most of the examples in this doc use `vit_base_patch16_224.onnx` as the input mo
5656

5757
```bash
5858
python download_example_onnx.py \
59-
--vit \
59+
--timm_model_name=vit_base_patch16_224 \
6060
--onnx_save_path=vit_base_patch16_224.onnx \
6161
--fp16 # <Optional, if the desired output ONNX precision is FP16>
6262
```

examples/onnx_ptq/download_example_onnx.py

Lines changed: 20 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import os
18-
import subprocess
1918

2019
import timm
2120
import torch
@@ -46,14 +45,10 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
4645
parser = argparse.ArgumentParser(description="Download and export example models to ONNX.")
4746

4847
parser.add_argument(
49-
"--vit",
50-
action="store_true",
51-
help="Export timm/vit_base_patch16_224 model to ONNX.",
52-
)
53-
parser.add_argument(
54-
"--llama",
55-
action="store_true",
56-
help="Export meta-llama/Llama-3.1-8B-Instruct to ONNX with KV cache.",
48+
"--timm_model_name",
49+
type=str,
50+
required=True,
51+
help="Export any timm model to ONNX (e.g., vit_base_patch16_224, swin_tiny_patch4_window7_224).",
5752
)
5853
parser.add_argument(
5954
"--onnx_save_path", type=str, required=False, help="Path to save the final ONNX model."
@@ -62,7 +57,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
6257
"--batch_size",
6358
type=int,
6459
default=1,
65-
help="Batch size for the exported ViT model.",
60+
help="Batch size for the exported model.",
6661
)
6762
parser.add_argument(
6863
"--fp16",
@@ -71,54 +66,18 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
7166
)
7267
args = parser.parse_args()
7368

74-
if args.vit:
75-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76-
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=1000).to(
77-
device
78-
)
79-
data_config = timm.data.resolve_model_data_config(model)
80-
input_shape = (args.batch_size,) + data_config["input_size"]
81-
82-
vit_save_path = args.onnx_save_path or "vit_base_patch16_224.onnx"
83-
weights_dtype = "fp16" if args.fp16 else "fp32"
84-
export_to_onnx(
85-
model,
86-
input_shape,
87-
vit_save_path,
88-
device,
89-
weights_dtype=weights_dtype,
90-
)
91-
print(f"ViT model exported to {vit_save_path}")
92-
93-
if args.llama:
94-
model_name = "meta-llama/Llama-3.1-8B-Instruct"
95-
if not args.onnx_save_path:
96-
args.onnx_save_path = "Llama-3.1-8B-Instruct/model.onnx"
97-
98-
output_dir = os.path.dirname(args.onnx_save_path)
99-
if not output_dir: # Handle cases where only filename is given (save in current dir)
100-
output_dir = "."
101-
os.makedirs(output_dir, exist_ok=True)
102-
103-
command = [
104-
"python",
105-
"-m",
106-
"optimum.commands.optimum_cli",
107-
"export",
108-
"onnx",
109-
"--model",
110-
model_name,
111-
"--task",
112-
"causal-lm-with-past",
113-
"--device",
114-
"cuda",
115-
"--fp16" if args.fp16 else "",
116-
output_dir,
117-
]
118-
119-
try:
120-
print(f"Running optimum-cli export to {output_dir}...")
121-
subprocess.run(command, check=True, capture_output=True, text=True, encoding="utf-8")
122-
print(f"Llama model exported to {output_dir}")
123-
except subprocess.CalledProcessError as e:
124-
raise RuntimeError(f"Failed to export model: {e.stderr}")
69+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70+
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
71+
data_config = timm.data.resolve_model_data_config(model)
72+
input_shape = (args.batch_size,) + data_config["input_size"]
73+
74+
save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx"
75+
weights_dtype = "fp16" if args.fp16 else "fp32"
76+
export_to_onnx(
77+
model,
78+
input_shape,
79+
save_path,
80+
device,
81+
weights_dtype=weights_dtype,
82+
)
83+
print(f"{args.timm_model_name} model exported to {save_path}")

examples/torch_onnx/README.md

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
5353

5454
- Loads a pretrained timm torch model (default: ViT-Base).
5555
- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
56+
- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
5657
- Exports the quantized model to ONNX.
5758
- Postprocesses the ONNX model to be compatible with TensorRT.
5859
- Saves the final ONNX model.
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
6364

6465
```bash
6566
python torch_quant_to_onnx.py \
66-
--timm_model_name=vit_base_patch16_224 \
67+
--timm_model_name=<timm model name> \
6768
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
6869
--onnx_save_path=<path to save the exported ONNX model>
6970
```
7071

72+
### Conv2d Quantization Override
73+
74+
TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:
75+
76+
| Quantize Mode | Conv2d Override | Reason |
77+
| :---: | :---: | :--- |
78+
| FP8, INT8 | None (already compatible) | Native TRT support |
79+
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
80+
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |
81+
7182
### Evaluation
7283

7384
If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
@@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
7990
--onnx_path=<path to the exported ONNX model> \
8091
--imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
8192
--engine_precision=stronglyTyped \
82-
--model_name=vit_base_patch16_224
93+
--model_name=<timm model name>
8394
```
8495

8596
## LLM Quantization and Export with TensorRT-Edge-LLM
@@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
289300
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
290301
```
291302

292-
### Results (ViT-Base)
303+
## ONNX Export Supported Vision Models
293304

294-
| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
295-
| :--- | :---: | :---: |
296-
| Torch autocast (FP16) | 85.11% | 97.53% |
297-
| NVFP4 Quantized | 84.558% | 97.36% |
298-
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
305+
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
306+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
307+
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
308+
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
309+
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
299310

300311
## Resources
301312

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import copy
18+
import json
1719
import re
1820
import sys
21+
import warnings
1922
from pathlib import Path
2023

2124
# Add onnx_ptq to path for shared modules
@@ -44,20 +47,69 @@
4447

4548
mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers
4649

47-
QUANT_CONFIG_DICT = {
50+
QUANT_CONFIG_DICT: dict[str, dict] = {
4851
"fp8": mtq.FP8_DEFAULT_CFG,
4952
"int8": mtq.INT8_DEFAULT_CFG,
5053
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
5154
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
5255
"int4_awq": mtq.INT4_AWQ_CFG,
5356
}
5457

58+
_FP8_CONV_OVERRIDE: list = [
59+
{
60+
"parent_class": "nn.Conv2d",
61+
"quantizer_name": "*weight_quantizer",
62+
"cfg": {"num_bits": (4, 3), "axis": None},
63+
},
64+
{
65+
"parent_class": "nn.Conv2d",
66+
"quantizer_name": "*input_quantizer",
67+
"cfg": {"num_bits": (4, 3), "axis": None},
68+
},
69+
]
70+
71+
_INT8_CONV_OVERRIDE: list = [
72+
{
73+
"parent_class": "nn.Conv2d",
74+
"quantizer_name": "*weight_quantizer",
75+
"cfg": {"num_bits": 8, "axis": 0},
76+
},
77+
{
78+
"parent_class": "nn.Conv2d",
79+
"quantizer_name": "*input_quantizer",
80+
"cfg": {"num_bits": 8, "axis": None},
81+
},
82+
]
83+
84+
85+
def get_quant_config(quantize_mode):
86+
"""Get quantization config, overriding Conv2d for TRT compatibility.
87+
88+
TensorRT only supports FP8 and INT8 for Conv layers.
89+
- For MXFP8, NVFP4: override Conv2d to FP8
90+
- For INT4_AWQ: override Conv2d to INT8
91+
"""
92+
config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
93+
if quantize_mode in ("mxfp8", "nvfp4"):
94+
warnings.warn(
95+
f"TensorRT only supports FP8/INT8 for Conv layers. "
96+
f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
97+
)
98+
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
99+
elif quantize_mode == "int4_awq":
100+
warnings.warn(
101+
"TensorRT only supports FP8/INT8 for Conv layers. "
102+
"Overriding Conv2d quantization to INT8 for 'int4_awq' mode."
103+
)
104+
config["quant_cfg"].extend(_INT8_CONV_OVERRIDE)
105+
return config
106+
55107

56108
def filter_func(name):
57109
"""Filter function to exclude certain layers from quantization."""
58110
pattern = re.compile(
59111
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
60-
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed).*"
112+
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
61113
)
62114
return pattern.match(name) is not None
63115

@@ -121,6 +173,18 @@ def loss_func(output, batch):
121173
return F.cross_entropy(output, batch["label"])
122174

123175

176+
def _disable_inplace_relu(model):
177+
"""Replace inplace ReLU with non-inplace ReLU throughout the model.
178+
179+
This is needed for auto_quantize which uses backward hooks for gradient-based
180+
sensitivity scoring. Inplace ReLU on views created by custom Functions causes
181+
PyTorch autograd errors.
182+
"""
183+
for module in model.modules():
184+
if isinstance(module, torch.nn.ReLU) and module.inplace:
185+
module.inplace = False
186+
187+
124188
def auto_quantize_model(
125189
model,
126190
data_loader,
@@ -142,6 +206,7 @@ def auto_quantize_model(
142206
Returns:
143207
Tuple of (quantized_model, search_state_dict)
144208
"""
209+
_disable_inplace_relu(model)
145210
constraints = {"effective_bits": effective_bits}
146211

147212
# Convert string format names to actual config objects
@@ -255,12 +320,29 @@ def main():
255320
default=128,
256321
help="Number of scoring steps for auto quantization. Default is 128.",
257322
)
323+
parser.add_argument(
324+
"--no_pretrained",
325+
action="store_true",
326+
help="Don't load pretrained weights (useful for testing with random weights).",
327+
)
328+
parser.add_argument(
329+
"--model_kwargs",
330+
type=str,
331+
default=None,
332+
help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').",
333+
)
258334

259335
args = parser.parse_args()
260336

261337
# Create model and move to appropriate device
262338
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
263-
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
339+
model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
340+
model = timm.create_model(
341+
args.timm_model_name,
342+
pretrained=not args.no_pretrained,
343+
num_classes=1000,
344+
**model_kwargs,
345+
).to(device)
264346

265347
# Get input shape from model config
266348
input_size = get_model_input_shape(model)
@@ -297,7 +379,7 @@ def main():
297379
)
298380
else:
299381
# Standard quantization - only load calibration data if needed
300-
config = QUANT_CONFIG_DICT[args.quantize_mode]
382+
config = get_quant_config(args.quantize_mode)
301383
if args.quantize_mode == "mxfp8":
302384
data_loader = None
303385
else:

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,9 +1162,12 @@ def cast_initializer_to_dtype(
11621162
node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
11631163
):
11641164
"""Casts the initializer to the given dtype."""
1165+
input_id = None
11651166
for id, input_name in enumerate(node.input):
11661167
if input_name in initializer_map:
11671168
input_id = id
1169+
if input_id is None:
1170+
return
11681171
input_name = node.input[input_id]
11691172
input = numpy_helper.to_array(initializer_map[input_name])
11701173
input = input.astype(np_dtype_map[dtype])

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ def get_onnx_bytes_and_metadata(
608608
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
609609
)
610610
# Change FP32 cast nodes feeding into Concat/Add to FP16
611-
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"])
611+
op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"]
612+
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list)
612613
else:
613614
onnx_opt_graph = convert_to_f16(
614615
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False

tests/_test_utils/torch/vision_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def forward(self, x):
6969
def _create_timm_fn(name):
7070
def get_model_and_input(on_gpu: bool = False):
7171
model = timm.create_model(name)
72-
return process_model_and_inputs(model, (torch.randn(1, 3, 224, 224),), {}, on_gpu)
72+
data_config = timm.data.resolve_model_data_config(model)
73+
input_size = data_config["input_size"] # e.g., (3, 224, 224)
74+
return process_model_and_inputs(model, (torch.randn(1, *input_size),), {}, on_gpu)
7375

7476
return get_model_and_input
7577

@@ -114,6 +116,7 @@ def get_model_and_input(on_gpu: bool = False):
114116
# "vovnet39a",
115117
# "dm_nfnet_f0",
116118
"efficientnet_b0",
119+
"swin_tiny_patch4_window7_224",
117120
],
118121
_create_timm_fn,
119122
),

0 commit comments

Comments
 (0)