[OMNIML-3349] Add SwinTransformer support for torch_onnx quantization workflow #1235
Changes from 1 commit
@@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
  - Loads a pretrained timm torch model (default: ViT-Base).
  - Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
+ - For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
  - Exports the quantized model to ONNX.
  - Postprocesses the ONNX model to be compatible with TensorRT.
  - Saves the final ONNX model.
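For orientation, here is a minimal sketch of that quantize-then-export flow using ModelOpt's standard PTQ API (`mtq.quantize`, the predefined `FP8_DEFAULT_CFG`, and `export_torch_mode` are ModelOpt's public API; the single-batch calibration loop and output names are simplifying assumptions, not the script's actual code):

```python
# Minimal sketch, assuming ModelOpt's standard PTQ + ONNX export flow.
import timm
import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

model = timm.create_model("vit_base_patch16_224", pretrained=True).eval().cuda()
dummy_input = torch.randn(1, 3, 224, 224, device="cuda")

def forward_loop(m):
    # The real script iterates over a calibration dataset; one batch keeps this minimal.
    m(dummy_input)

# FP8 shown for brevity; the other modes swap in different predefined configs.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

with torch.inference_mode(), export_torch_mode():
    torch.onnx.export(model, dummy_input, "model.fp8.onnx", opset_version=17)
```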
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
```bash
python torch_quant_to_onnx.py \
-   --timm_model_name=vit_base_patch16_224 \
+   --timm_model_name=<timm model name> \
    --quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
    --onnx_save_path=<path to save the exported ONNX model>
```
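For instance, a SwinTransformer run matching this PR's subject might look like the following (the output filename is a hypothetical choice; the model name and flags come from this README):

```bash
python torch_quant_to_onnx.py \
    --timm_model_name=swin_tiny_patch4_window7_224 \
    --quantize_mode=nvfp4 \
    --onnx_save_path=swin_tiny_patch4_window7_224.nvfp4.onnx
```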

### Conv2d Quantization Override

TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:

| Quantize Mode | Conv2d Override | Reason |
| :---: | :---: | :--- |
| FP8, INT8 | None (already compatible) | Native TRT support |
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |
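A hypothetical sketch of the override selection this table describes (the function name and return convention are illustrative only, not the script's actual code):

```python
# Illustrative only: map the requested quantize mode to the format that
# TensorRT convolutions can actually consume (FP8 or INT8).
def conv2d_quant_override(quantize_mode: str) -> str | None:
    if quantize_mode in ("fp8", "int8"):
        return None  # already TRT-compatible, no override needed
    if quantize_mode in ("mxfp8", "nvfp4"):
        return "fp8"  # TRT convolutions do not support MXFP8/NVFP4
    if quantize_mode == "int4_awq":
        return "int8"  # TRT convolutions do not support INT4
    return None
```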
### Evaluation

If the input model is an image classification model, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via a Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
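For example, authentication can be done once up front (assuming the `huggingface_hub` package is installed; the token placeholder is yours to fill in):

```bash
# Interactive login caches the token locally for later dataset downloads.
huggingface-cli login
# Or, non-interactively for the current shell session:
export HF_TOKEN=<your Hugging Face access token>
```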
@@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
    --onnx_path=<path to the exported ONNX model> \
    --imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
    --engine_precision=stronglyTyped \
-   --model_name=vit_base_patch16_224
+   --model_name=<timm model name>
```

## LLM Quantization and Export with TensorRT-Edge-LLM
@@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
    --onnx_save_path=vit_base_patch16_224.auto_quant.onnx
```
-### Results (ViT-Base)
+## ONNX Export Supported Vision Models

-| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
-| :--- | :---: | :---: |
-| Torch autocast (FP16) | 85.11% | 97.53% |
-| NVFP4 Quantized | 84.558% | 97.36% |
-| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
+| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

Comment on lines +305 to +309 (Contributor):

Support matrix overstates Auto support for swinv2_tiny_window8_256.

📝 Suggested docs correction:

-| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |

## Resources

tests/examples/torch_onnx/test_torch_quant_to_onnx.py
@@ -14,28 +14,67 @@
 # limitations under the License.
+import os
+import subprocess

 import pytest
 from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
+# TODO: Add int4_awq once the INT4 exporter supports non-MatMul/Gemm consumer patterns
+# (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs).
+_QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"]
+
+_MODELS = {
+    "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
+    "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
+    "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
+}
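The `model_kwargs` JSON strings above presumably shrink each model to a single block per stage for fast CI runs; a hypothetical sketch of how such kwargs would reach timm (illustrative, not this test's code):

```python
# Illustrative: build a one-block-per-stage SwinTransformer from the kwargs above.
import json

import timm

name, kwargs_json = "swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'
model = timm.create_model(name, pretrained=False, **json.loads(kwargs_json))
```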

+# Builder optimization level: 4 for low-bit modes, 3 otherwise
+_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}

+def _verify_trt_engine_build(onnx_save_path, quantize_mode):
+    """Verify the exported ONNX model can be compiled into a TensorRT engine."""
+    example_dir = os.path.join(
+        os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
+    )
+    onnx_path = os.path.join(example_dir, onnx_save_path)
+    assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"
+
+    opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
+    cmd = [
+        "trtexec",
+        f"--onnx={onnx_path}",
+        "--stronglyTyped",
+        f"--builderOptimizationLevel={opt_level}",
+    ]
+
+    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)  # nosec
coderabbitai[bot] marked this conversation as resolved (outdated).
+    assert result.returncode == 0, (
+        f"TensorRT engine build failed for {onnx_save_path} "
+        f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
+    )

Comment on lines +37 to +57 (Contributor):

🧩 Analysis chain: verification scripts run against tests/examples/torch_onnx/test_torch_quant_to_onnx.py (rg for a trtexec skip guard, shutil.which usage, and pytest skip conditions) found no availability check for trtexec.

Guard the trtexec invocation: _verify_trt_engine_build calls subprocess.run(["trtexec", ...]) unconditionally, so the test errors out with FileNotFoundError instead of skipping when trtexec is not in PATH. Add an availability check at the start of _verify_trt_engine_build.

Suggested fix:

+import shutil
 import subprocess
@@
 def _verify_trt_engine_build(onnx_save_path, quantize_mode):
     """Verify the exported ONNX model can be compiled into a TensorRT engine."""
+    if shutil.which("trtexec") is None:
+        pytest.skip("Skipping: `trtexec` is not available in PATH.")
+
     example_dir = os.path.join(
         os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
     )
+@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
+@pytest.mark.parametrize("model_key", list(_MODELS))
+def test_torch_onnx(model_key, quantize_mode):
+    timm_model_name, model_kwargs = _MODELS[model_key]
+    onnx_save_path = f"{model_key}.{quantize_mode}.onnx"

-# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12
-@pytest.mark.parametrize(
-    ("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"),
-    [
-        ("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"),
-        ("int8", "vit_base_patch16_224.int8.onnx", "1", "1"),
-        ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"),
-        ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"),
-        ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"),
-        ("auto", "vit_base_patch16_224.auto.onnx", "1", "1"),
-    ],
-)
-def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps):
+    # Step 1: Quantize and export to ONNX
     cmd_parts = extend_cmd_parts(
         ["python", "torch_quant_to_onnx.py"],
+        timm_model_name=timm_model_name,
+        model_kwargs=model_kwargs,
         quantize_mode=quantize_mode,
         onnx_save_path=onnx_save_path,
-        calibration_data_size=calib_size,
-        num_score_steps=num_score_steps,
+        calibration_data_size="1",
+        num_score_steps="1",
     )
+    cmd_parts.append("--no_pretrained")
     run_example_command(cmd_parts, "torch_onnx")

+    # Step 2: Verify the exported ONNX model builds a TensorRT engine
+    _verify_trt_engine_build(onnx_save_path, quantize_mode)
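For reference, a single cell of the model x mode matrix can be run on its own (hypothetical invocation; assumes trtexec and the example dependencies are installed):

```bash
pytest tests/examples/torch_onnx/test_torch_quant_to_onnx.py -k "swin_tiny and fp8" -v
```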