# Add SwinTransformer support for torch_onnx quantization workflow (#1235)
## Summary
- Enable end-to-end quantize → ONNX export → TRT engine pipeline for
SwinTransformer models (v1 and v2) across FP8, INT8, MXFP8, NVFP4, and
auto precision modes
- Add Conv2d quantization overrides for TRT compatibility (TRT only
supports FP8/INT8 for convolutions)
- Fix FP8 LayerNorm type mismatch in TRT stronglyTyped mode by adding
`LayerNormalization` to `change_casts_to_fp16`
- Fix `cast_initializer_to_dtype` crash when node has no initializer
inputs
- Simplify `download_example_onnx.py` to a single `--timm_model_name`
(required) flag, removing redundant `--vit` and `--llama` flags
- Add vision model support matrix to README (ViT, Swin, SwinV2)
- Rewrite tests: parametrize over (ViT, Swin, SwinV2) × (fp8, int8,
mxfp8, nvfp4, auto) with TRT engine build verification
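The Conv2d override policy above boils down to a small mode-to-precision mapping. A minimal sketch (the function name and structure are illustrative; only the mapping itself comes from this PR's description):

```python
# Illustrative sketch of the Conv2d precision override described above.
# TensorRT supports only FP8/INT8 convolutions, so other modes fall back
# to a compatible precision for Conv2d layers. Function name is hypothetical.

def conv2d_override(quantize_mode: str):
    """Return the precision Conv2d layers should be overridden to, or None."""
    if quantize_mode in ("fp8", "int8"):
        return None                  # already TRT-compatible, no override
    if quantize_mode in ("mxfp8", "nvfp4"):
        return "fp8"                 # Conv2d overridden to FP8
    if quantize_mode == "int4_awq":
        return "int8"                # Conv2d overridden to INT8
    raise ValueError(f"unknown quantize mode: {quantize_mode!r}")
```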
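The FP8 LayerNorm fix amounts to adding one op type to the set of ops whose surrounding casts are rewritten to FP16. A schematic version (the real pass operates on an ONNX graph; the node representation and the other op name here are simplifications):

```python
# Schematic version of the change_casts_to_fp16 fix: LayerNormalization is
# now in the set, so Casts feeding LayerNorm are retargeted to FP16 and the
# type mismatch under TRT stronglyTyped mode goes away.

FP16_CAST_OP_TYPES = {
    "Softmax",              # placeholder for previously handled ops
    "LayerNormalization",   # the fix: LayerNorm casts now go to FP16
}

def rewrite_casts(nodes):
    """nodes: list of dicts with 'op_type', 'consumer', and 'cast_to' keys."""
    for node in nodes:
        if node["op_type"] == "Cast" and node.get("consumer") in FP16_CAST_OP_TYPES:
            node["cast_to"] = "float16"
    return nodes
```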
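The `cast_initializer_to_dtype` crash fix is essentially an early return when a node has no initializer inputs. Schematically (data structures are simplified stand-ins for the real ONNX types):

```python
def cast_initializers_for_node(node_inputs, initializers, cast_fn):
    """Cast a node's initializer inputs in place; skip nodes without any.

    Code along these lines previously assumed at least one initializer
    input existed and crashed otherwise; the fix is the early return.
    'initializers' maps name -> value (schematic, not onnx.TensorProto).
    """
    init_names = [name for name in node_inputs if name in initializers]
    if not init_names:
        return False            # nothing to cast; no longer a crash
    for name in init_names:
        initializers[name] = cast_fn(initializers[name])
    return True
```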
## Test plan
- [ ] `python -m pytest
tests/examples/torch_onnx/test_torch_quant_to_onnx.py -v` — 15 tests (3
models × 5 modes), all pass
- [ ] Verified SwinTransformer accuracy on ImageNet-1k across all
precisions: TRT top-1 of 81.29% (FP8), 81.12% (INT8), 81.32% (MXFP8),
80.79% (NVFP4), and 80.84% (auto), vs. the 81.37% unquantized baseline
- [ ] INT4_AWQ deferred (TODO in test file) — requires INT4 exporter
changes for non-MatMul/Gemm consumer patterns
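The test matrix in the plan (3 models × 5 modes = 15 cases) could be laid out as a single parametrization; the timm model names and test body below are assumptions, not the actual test file:

```python
import itertools

import pytest

# Hypothetical timm model names standing in for the (ViT, Swin, SwinV2) trio.
MODELS = [
    "vit_base_patch16_224",
    "swin_tiny_patch4_window7_224",
    "swinv2_tiny_window8_256",
]
MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"]
CASES = list(itertools.product(MODELS, MODES))  # 15 combinations

@pytest.mark.parametrize("model_name,quantize_mode", CASES)
def test_quant_to_onnx_to_trt(model_name, quantize_mode):
    # The real test quantizes the model, exports ONNX, and verifies
    # that a TensorRT engine builds from the exported graph.
    ...
```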
🤖 Generated with [Claude Code](https://claude.com/claude-code)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
  * ONNX export supports arbitrary timm vision models with auto device
    selection and new CLI options (`--timm_model_name`, `--model_kwargs`,
    `--no_pretrained`); batch-size/input sizing is now model-generic.
* **Bug Fixes**
  * Expanded FP16/BF16 cast handling to additional ONNX ops.
  * Disabled inplace ReLU before auto-quantization to avoid incorrect
    transforms.
  * Conv2d quantization overrides added for improved TensorRT
    compatibility.
  * Safer handling when initializers are missing during dtype casting.
* **Documentation**
  * README updated with supported models table, quantization mappings, and
    example CLI usage.
* **Tests**
  * Tests expanded to multiple architectures/quant modes and now verify
    TensorRT engine build.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
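A minimal sketch of the simplified CLI mentioned above (the flag names `--timm_model_name`, `--model_kwargs`, and `--no_pretrained` come from this PR; help text and defaults are assumptions):

```python
import argparse
import json

def build_parser():
    parser = argparse.ArgumentParser(
        description="Download/export a timm vision model to ONNX (sketch)")
    # --timm_model_name replaces the old --vit / --llama flags and is required.
    parser.add_argument("--timm_model_name", required=True,
                        help="Any timm model name, e.g. a ViT or Swin variant")
    parser.add_argument("--model_kwargs", type=json.loads, default={},
                        help="JSON dict forwarded to the model constructor")
    parser.add_argument("--no_pretrained", action="store_true",
                        help="Skip loading pretrained weights")
    return parser
```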
---------
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
### `examples/torch_onnx/README.md` (+19 −8)
````diff
@@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
 - Loads a pretrained timm torch model (default: ViT-Base).
 - Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
+- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
 - Exports the quantized model to ONNX.
 - Postprocesses the ONNX model to be compatible with TensorRT.
 - Saves the final ONNX model.
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
 ```bash
 python torch_quant_to_onnx.py \
-    --timm_model_name=vit_base_patch16_224 \
+    --timm_model_name=<timm model name> \
     --quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
     --onnx_save_path=<path to save the exported ONNX model>
 ```
 
+### Conv2d Quantization Override
+
+TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:
+
+| Quantize Mode | Conv2d Override | Reason |
+| :---: | :---: | :--- |
+| FP8, INT8 | None (already compatible) | Native TRT support |
````

If the input model is an image-classification model, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via a Hugging Face access token; see <https://huggingface.co/docs/hub/en/security-tokens> for details.