Skip to content

Commit eaffcec

Browse files
ajrasaneclaude
andcommitted
Add SwinTransformer support for torch_onnx quantization workflow
Enable end-to-end quantize-export-TRT pipeline for SwinTransformer models (v1 and v2) across FP8, INT8, MXFP8, NVFP4, and auto precision modes. Core fixes: - Add LayerNormalization, Clip, Mul, Exp to change_casts_to_fp16 for FP8 stronglyTyped compatibility (fixes type mismatches in Swin/SwinV2 TRT builds) Example/test changes: - Add Conv2d quantization overrides for TRT compatibility (MXFP8/NVFP4->FP8, INT4_AWQ->INT8) since TRT only supports FP8/INT8 for convolutions - Add cpb_mlp and downsample to quantization filter exclusion list - Add --no_pretrained and --model_kwargs CLI args for testing with tiny models - Add --timm_model_name to download_example_onnx.py (default: ViT) - Add SwinTransformer to vision_models.py with dynamic input size resolution - Rewrite tests: parametrize over (ViT, Swin, SwinV2) x (fp8, int8, mxfp8, nvfp4, auto) with TRT engine build verification using --stronglyTyped - Update README with vision model support matrix and Conv2d override docs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 5523505 commit eaffcec

File tree

6 files changed

+191
-30
lines changed

6 files changed

+191
-30
lines changed

examples/onnx_ptq/download_example_onnx.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
5050
action="store_true",
5151
help="Export timm/vit_base_patch16_224 model to ONNX.",
5252
)
53+
parser.add_argument(
54+
"--timm_model_name",
55+
type=str,
56+
default="vit_base_patch16_224",
57+
help="Export any timm model to ONNX (e.g., swin_tiny_patch4_window7_224).",
58+
)
5359
parser.add_argument(
5460
"--llama",
5561
action="store_true",
@@ -62,7 +68,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
6268
"--batch_size",
6369
type=int,
6470
default=1,
65-
help="Batch size for the exported ViT model.",
71+
help="Batch size for the exported model.",
6672
)
6773
parser.add_argument(
6874
"--fp16",
@@ -90,6 +96,25 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
9096
)
9197
print(f"ViT model exported to {vit_save_path}")
9298

99+
if args.timm_model_name:
100+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101+
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(
102+
device
103+
)
104+
data_config = timm.data.resolve_model_data_config(model)
105+
input_shape = (args.batch_size,) + data_config["input_size"]
106+
107+
save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx"
108+
weights_dtype = "fp16" if args.fp16 else "fp32"
109+
export_to_onnx(
110+
model,
111+
input_shape,
112+
save_path,
113+
device,
114+
weights_dtype=weights_dtype,
115+
)
116+
print(f"{args.timm_model_name} model exported to {save_path}")
117+
93118
if args.llama:
94119
model_name = "meta-llama/Llama-3.1-8B-Instruct"
95120
if not args.onnx_save_path:

examples/torch_onnx/README.md

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
5353

5454
- Loads a pretrained timm torch model (default: ViT-Base).
5555
- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
56+
- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
5657
- Exports the quantized model to ONNX.
5758
- Postprocesses the ONNX model to be compatible with TensorRT.
5859
- Saves the final ONNX model.
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
6364

6465
```bash
6566
python torch_quant_to_onnx.py \
66-
--timm_model_name=vit_base_patch16_224 \
67+
--timm_model_name=<timm model name> \
6768
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
6869
--onnx_save_path=<path to save the exported ONNX model>
6970
```
7071

72+
### Conv2d Quantization Override
73+
74+
TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:
75+
76+
| Quantize Mode | Conv2d Override | Reason |
77+
| :---: | :---: | :--- |
78+
| FP8, INT8 | None (already compatible) | Native TRT support |
79+
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
80+
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |
81+
7182
### Evaluation
7283

7384
If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
@@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
7990
--onnx_path=<path to the exported ONNX model> \
8091
--imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
8192
--engine_precision=stronglyTyped \
82-
--model_name=vit_base_patch16_224
93+
--model_name=<timm model name>
8394
```
8495

8596
## LLM Quantization and Export with TensorRT-Edge-LLM
@@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
289300
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
290301
```
291302

292-
### Results (ViT-Base)
303+
## ONNX Export Supported Vision Models
293304

294-
| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
295-
| :--- | :---: | :---: |
296-
| Torch autocast (FP16) | 85.11% | 97.53% |
297-
| NVFP4 Quantized | 84.558% | 97.36% |
298-
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
305+
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
306+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
307+
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | | |||| |
308+
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) || | ||| |
309+
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) ||||| | |
299310

300311
## Resources
301312

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import copy
18+
import json
1719
import re
1820
import sys
21+
import warnings
1922
from pathlib import Path
2023

2124
# Add onnx_ptq to path for shared modules
@@ -44,20 +47,69 @@
4447

4548
mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers
4649

47-
QUANT_CONFIG_DICT = {
50+
QUANT_CONFIG_DICT: dict[str, dict] = {
4851
"fp8": mtq.FP8_DEFAULT_CFG,
4952
"int8": mtq.INT8_DEFAULT_CFG,
5053
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
5154
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
5255
"int4_awq": mtq.INT4_AWQ_CFG,
5356
}
5457

58+
_FP8_CONV_OVERRIDE: list = [
59+
{
60+
"parent_class": "nn.Conv2d",
61+
"quantizer_name": "*weight_quantizer",
62+
"cfg": {"num_bits": (4, 3), "axis": None},
63+
},
64+
{
65+
"parent_class": "nn.Conv2d",
66+
"quantizer_name": "*input_quantizer",
67+
"cfg": {"num_bits": (4, 3), "axis": None},
68+
},
69+
]
70+
71+
_INT8_CONV_OVERRIDE: list = [
72+
{
73+
"parent_class": "nn.Conv2d",
74+
"quantizer_name": "*weight_quantizer",
75+
"cfg": {"num_bits": 8, "axis": 0},
76+
},
77+
{
78+
"parent_class": "nn.Conv2d",
79+
"quantizer_name": "*input_quantizer",
80+
"cfg": {"num_bits": 8, "axis": None},
81+
},
82+
]
83+
84+
85+
def get_quant_config(quantize_mode):
86+
"""Get quantization config, overriding Conv2d for TRT compatibility.
87+
88+
TensorRT only supports FP8 and INT8 for Conv layers.
89+
- For MXFP8, NVFP4: override Conv2d to FP8
90+
- For INT4_AWQ: override Conv2d to INT8
91+
"""
92+
config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
93+
if quantize_mode in ("mxfp8", "nvfp4"):
94+
warnings.warn(
95+
f"TensorRT only supports FP8/INT8 for Conv layers. "
96+
f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
97+
)
98+
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
99+
elif quantize_mode == "int4_awq":
100+
warnings.warn(
101+
"TensorRT only supports FP8/INT8 for Conv layers. "
102+
"Overriding Conv2d quantization to INT8 for 'int4_awq' mode."
103+
)
104+
config["quant_cfg"].extend(_INT8_CONV_OVERRIDE)
105+
return config
106+
55107

56108
def filter_func(name):
57109
"""Filter function to exclude certain layers from quantization."""
58110
pattern = re.compile(
59111
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
60-
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed).*"
112+
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
61113
)
62114
return pattern.match(name) is not None
63115

@@ -121,6 +173,18 @@ def loss_func(output, batch):
121173
return F.cross_entropy(output, batch["label"])
122174

123175

176+
def _disable_inplace_relu(model):
177+
"""Replace inplace ReLU with non-inplace ReLU throughout the model.
178+
179+
This is needed for auto_quantize which uses backward hooks for gradient-based
180+
sensitivity scoring. Inplace ReLU on views created by custom Functions causes
181+
PyTorch autograd errors.
182+
"""
183+
for module in model.modules():
184+
if isinstance(module, torch.nn.ReLU) and module.inplace:
185+
module.inplace = False
186+
187+
124188
def auto_quantize_model(
125189
model,
126190
data_loader,
@@ -142,6 +206,7 @@ def auto_quantize_model(
142206
Returns:
143207
Tuple of (quantized_model, search_state_dict)
144208
"""
209+
_disable_inplace_relu(model)
145210
constraints = {"effective_bits": effective_bits}
146211

147212
# Convert string format names to actual config objects
@@ -255,12 +320,29 @@ def main():
255320
default=128,
256321
help="Number of scoring steps for auto quantization. Default is 128.",
257322
)
323+
parser.add_argument(
324+
"--no_pretrained",
325+
action="store_true",
326+
help="Don't load pretrained weights (useful for testing with random weights).",
327+
)
328+
parser.add_argument(
329+
"--model_kwargs",
330+
type=str,
331+
default=None,
332+
help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').",
333+
)
258334

259335
args = parser.parse_args()
260336

261337
# Create model and move to appropriate device
262338
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
263-
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
339+
model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
340+
model = timm.create_model(
341+
args.timm_model_name,
342+
pretrained=not args.no_pretrained,
343+
num_classes=1000,
344+
**model_kwargs,
345+
).to(device)
264346

265347
# Get input shape from model config
266348
input_size = get_model_input_shape(model)
@@ -297,7 +379,7 @@ def main():
297379
)
298380
else:
299381
# Standard quantization - only load calibration data if needed
300-
config = QUANT_CONFIG_DICT[args.quantize_mode]
382+
config = get_quant_config(args.quantize_mode)
301383
if args.quantize_mode == "mxfp8":
302384
data_loader = None
303385
else:

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ def get_onnx_bytes_and_metadata(
608608
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
609609
)
610610
# Change FP32 cast nodes feeding into Concat/Add to FP16
611-
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"])
611+
op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"]
612+
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list)
612613
else:
613614
onnx_opt_graph = convert_to_f16(
614615
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False

tests/_test_utils/torch/vision_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def forward(self, x):
6969
def _create_timm_fn(name):
7070
def get_model_and_input(on_gpu: bool = False):
7171
model = timm.create_model(name)
72-
return process_model_and_inputs(model, (torch.randn(1, 3, 224, 224),), {}, on_gpu)
72+
data_config = timm.data.resolve_model_data_config(model)
73+
input_size = data_config["input_size"] # e.g., (3, 224, 224)
74+
return process_model_and_inputs(model, (torch.randn(1, *input_size),), {}, on_gpu)
7375

7476
return get_model_and_input
7577

@@ -114,6 +116,7 @@ def get_model_and_input(on_gpu: bool = False):
114116
# "vovnet39a",
115117
# "dm_nfnet_f0",
116118
"efficientnet_b0",
119+
"swin_tiny_patch4_window7_224",
117120
],
118121
_create_timm_fn,
119122
),

tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,67 @@
1414
# limitations under the License.
1515

1616

17+
import os
18+
import subprocess
19+
1720
import pytest
1821
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
1922

23+
# TODO: Add int4_awq once the INT4 exporter supports non-MatMul/Gemm consumer patterns
24+
# (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs).
25+
_QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"]
26+
27+
_MODELS = {
28+
"vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
29+
"swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
30+
"swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
31+
}
32+
33+
# Builder optimization level: 4 for low-bit modes, 3 otherwise
34+
_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}
35+
36+
37+
def _verify_trt_engine_build(onnx_save_path, quantize_mode):
38+
"""Verify the exported ONNX model can be compiled into a TensorRT engine."""
39+
example_dir = os.path.join(
40+
os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
41+
)
42+
onnx_path = os.path.join(example_dir, onnx_save_path)
43+
assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"
44+
45+
opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
46+
cmd = [
47+
"trtexec",
48+
f"--onnx={onnx_path}",
49+
"--stronglyTyped",
50+
f"--builderOptimizationLevel={opt_level}",
51+
]
52+
53+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) # nosec
54+
assert result.returncode == 0, (
55+
f"TensorRT engine build failed for {onnx_save_path} "
56+
f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
57+
)
58+
59+
60+
@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
61+
@pytest.mark.parametrize("model_key", list(_MODELS))
62+
def test_torch_onnx(model_key, quantize_mode):
63+
timm_model_name, model_kwargs = _MODELS[model_key]
64+
onnx_save_path = f"{model_key}.{quantize_mode}.onnx"
2065

21-
# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12
22-
@pytest.mark.parametrize(
23-
("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"),
24-
[
25-
("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"),
26-
("int8", "vit_base_patch16_224.int8.onnx", "1", "1"),
27-
("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"),
28-
("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"),
29-
("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"),
30-
("auto", "vit_base_patch16_224.auto.onnx", "1", "1"),
31-
],
32-
)
33-
def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps):
66+
# Step 1: Quantize and export to ONNX
3467
cmd_parts = extend_cmd_parts(
3568
["python", "torch_quant_to_onnx.py"],
69+
timm_model_name=timm_model_name,
70+
model_kwargs=model_kwargs,
3671
quantize_mode=quantize_mode,
3772
onnx_save_path=onnx_save_path,
38-
calibration_data_size=calib_size,
39-
num_score_steps=num_score_steps,
73+
calibration_data_size="1",
74+
num_score_steps="1",
4075
)
76+
cmd_parts.append("--no_pretrained")
4177
run_example_command(cmd_parts, "torch_onnx")
78+
79+
# Step 2: Verify the exported ONNX model builds a TensorRT engine
80+
_verify_trt_engine_build(onnx_save_path, quantize_mode)

0 commit comments

Comments
 (0)