Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion examples/onnx_ptq/download_example_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
action="store_true",
help="Export timm/vit_base_patch16_224 model to ONNX.",
)
parser.add_argument(
"--timm_model_name",
type=str,
default="vit_base_patch16_224",
help="Export any timm model to ONNX (e.g., swin_tiny_patch4_window7_224).",
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
parser.add_argument(
"--llama",
action="store_true",
Expand All @@ -62,7 +68,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
"--batch_size",
type=int,
default=1,
help="Batch size for the exported ViT model.",
help="Batch size for the exported model.",
)
parser.add_argument(
"--fp16",
Expand Down Expand Up @@ -90,6 +96,25 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
)
print(f"ViT model exported to {vit_save_path}")

if args.timm_model_name:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(
device
)
data_config = timm.data.resolve_model_data_config(model)
input_shape = (args.batch_size,) + data_config["input_size"]

save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx"
weights_dtype = "fp16" if args.fp16 else "fp32"
export_to_onnx(
model,
input_shape,
save_path,
device,
weights_dtype=weights_dtype,
)
print(f"{args.timm_model_name} model exported to {save_path}")

if args.llama:
model_name = "meta-llama/Llama-3.1-8B-Instruct"
if not args.onnx_save_path:
Expand Down
27 changes: 19 additions & 8 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf

- Loads a pretrained timm torch model (default: ViT-Base).
- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
- Exports the quantized model to ONNX.
- Postprocesses the ONNX model to be compatible with TensorRT.
- Saves the final ONNX model.
Expand All @@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf

```bash
python torch_quant_to_onnx.py \
--timm_model_name=vit_base_patch16_224 \
--timm_model_name=<timm model name> \
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
--onnx_save_path=<path to save the exported ONNX model>
```

### Conv2d Quantization Override

TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:

| Quantize Mode | Conv2d Override | Reason |
| :---: | :---: | :--- |
| FP8, INT8 | None (already compatible) | Native TRT support |
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |

### Evaluation

If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
Expand All @@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
--onnx_path=<path to the exported ONNX model> \
--imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
--model_name=<timm model name>
```

## LLM Quantization and Export with TensorRT-Edge-LLM
Expand Down Expand Up @@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
```

### Results (ViT-Base)
## ONNX Export Supported Vision Models

| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
| :--- | :---: | :---: |
| Torch autocast (FP16) | 85.11% | 97.53% |
| NVFP4 Quantized | 84.558% | 97.36% |
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | |
Comment on lines +305 to +309
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Support matrix overstates swinv2_tiny Auto support

Line 309 marks Auto as ✅ for swinv2_tiny_window8_256, but tests explicitly skip that combo (tests/examples/torch_onnx/test_torch_quant_to_onnx.py Line 35). Please mark it unsupported (or add a footnote with the current limitation).

📝 Suggested docs correction
-| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||| |
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||| |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/README.md` around lines 305 - 309, Update the support
matrix row for the model identifier "swinv2_tiny_window8_256" in README.md to
remove the Auto ✅ (change to ❌) or add a footnote explaining it's currently
unsupported; reference the test that skips this combo
(test_torch_quant_to_onnx.py) as the reason for the change so readers know the
limitation is intentional.


## Resources

Expand Down
90 changes: 86 additions & 4 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
# limitations under the License.

import argparse
import copy
import json
import re
import sys
import warnings
from pathlib import Path

# Add onnx_ptq to path for shared modules
Expand Down Expand Up @@ -44,20 +47,69 @@

mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers

QUANT_CONFIG_DICT = {
QUANT_CONFIG_DICT: dict[str, dict] = {
"fp8": mtq.FP8_DEFAULT_CFG,
"int8": mtq.INT8_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
}

_FP8_CONV_OVERRIDE: list = [
{
"parent_class": "nn.Conv2d",
"quantizer_name": "*weight_quantizer",
"cfg": {"num_bits": (4, 3), "axis": None},
},
{
"parent_class": "nn.Conv2d",
"quantizer_name": "*input_quantizer",
"cfg": {"num_bits": (4, 3), "axis": None},
},
]

_INT8_CONV_OVERRIDE: list = [
{
"parent_class": "nn.Conv2d",
"quantizer_name": "*weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
{
"parent_class": "nn.Conv2d",
"quantizer_name": "*input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
]


def get_quant_config(quantize_mode):
    """Build the quantization config for *quantize_mode*, patching Conv2d for TRT.

    TensorRT only supports FP8 and INT8 for Conv layers, so:
    - mxfp8 / nvfp4: Conv2d quantizers are overridden to FP8
    - int4_awq:      Conv2d quantizers are overridden to INT8
    The base config is deep-copied so the module-level templates stay pristine.
    """
    config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])

    # Pick the Conv2d override (if any) for this mode.
    if quantize_mode in ("mxfp8", "nvfp4"):
        conv_override, target_fmt = _FP8_CONV_OVERRIDE, "FP8"
    elif quantize_mode == "int4_awq":
        conv_override, target_fmt = _INT8_CONV_OVERRIDE, "INT8"
    else:
        conv_override, target_fmt = None, None

    if conv_override is not None:
        warnings.warn(
            f"TensorRT only supports FP8/INT8 for Conv layers. "
            f"Overriding Conv2d quantization to {target_fmt} for '{quantize_mode}' mode."
        )
        config["quant_cfg"].extend(conv_override)
    return config


def filter_func(name):
    """Return True if *name* belongs to a layer that must stay unquantized."""
    # Layer-name fragments that disable quantization wherever they appear.
    excluded_fragments = (
        "time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
        "pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample"
    )
    return re.match(rf".*({excluded_fragments}).*", name) is not None

Expand Down Expand Up @@ -121,6 +173,18 @@ def loss_func(output, batch):
return F.cross_entropy(output, batch["label"])


def _disable_inplace_relu(model):
"""Replace inplace ReLU with non-inplace ReLU throughout the model.

This is needed for auto_quantize which uses backward hooks for gradient-based
sensitivity scoring. Inplace ReLU on views created by custom Functions causes
PyTorch autograd errors.
"""
for module in model.modules():
if isinstance(module, torch.nn.ReLU) and module.inplace:
module.inplace = False


def auto_quantize_model(
model,
data_loader,
Expand All @@ -142,6 +206,7 @@ def auto_quantize_model(
Returns:
Tuple of (quantized_model, search_state_dict)
"""
_disable_inplace_relu(model)
constraints = {"effective_bits": effective_bits}

# Convert string format names to actual config objects
Expand Down Expand Up @@ -255,12 +320,29 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
help="Don't load pretrained weights (useful for testing with random weights).",
)
parser.add_argument(
"--model_kwargs",
type=str,
default=None,
help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').",
)

args = parser.parse_args()

# Create model and move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
model = timm.create_model(
args.timm_model_name,
pretrained=not args.no_pretrained,
num_classes=1000,
**model_kwargs,
).to(device)

# Get input shape from model config
input_size = get_model_input_shape(model)
Expand Down Expand Up @@ -297,7 +379,7 @@ def main():
)
else:
# Standard quantization - only load calibration data if needed
config = QUANT_CONFIG_DICT[args.quantize_mode]
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
Expand Down
3 changes: 2 additions & 1 deletion modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ def get_onnx_bytes_and_metadata(
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"])
op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"]
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list)
else:
onnx_opt_graph = convert_to_f16(
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
Expand Down
5 changes: 4 additions & 1 deletion tests/_test_utils/torch/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def forward(self, x):
def _create_timm_fn(name):
    """Return a factory that builds the timm model *name* with a matching dummy input."""

    def get_model_and_input(on_gpu: bool = False):
        model = timm.create_model(name)
        # Resolve the model-specific input resolution instead of assuming 224x224.
        chw = timm.data.resolve_model_data_config(model)["input_size"]
        dummy_input = torch.randn(1, *chw)
        return process_model_and_inputs(model, (dummy_input,), {}, on_gpu)

    return get_model_and_input

Expand Down Expand Up @@ -114,6 +116,7 @@ def get_model_and_input(on_gpu: bool = False):
# "vovnet39a",
# "dm_nfnet_f0",
"efficientnet_b0",
"swin_tiny_patch4_window7_224",
],
_create_timm_fn,
),
Expand Down
69 changes: 54 additions & 15 deletions tests/examples/torch_onnx/test_torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,67 @@
# limitations under the License.


import os
import shutil
import subprocess

import pytest
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command

# TODO: Add int4_awq once the INT4 exporter supports non-MatMul/Gemm consumer patterns
# (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs).
_QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"]

_MODELS = {
"vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
"swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
"swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
}

# Builder optimization level: 4 for low-bit modes, 3 otherwise
_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}


def _verify_trt_engine_build(onnx_save_path, quantize_mode):
    """Verify the exported ONNX model can be compiled into a TensorRT engine.

    Skips (rather than hard-fails) when the ``trtexec`` CLI is not on PATH,
    so environments without the TensorRT toolchain don't see a
    FileNotFoundError from subprocess.
    """
    # Guard: trtexec may be absent in CI images without TensorRT installed.
    if shutil.which("trtexec") is None:
        pytest.skip("Skipping: `trtexec` is not available in PATH.")

    example_dir = os.path.join(
        os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
    )
    onnx_path = os.path.join(example_dir, onnx_save_path)
    assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"

    # Low-bit modes need a higher builder optimization level; others use 3.
    opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        f"--builderOptimizationLevel={opt_level}",
    ]

    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)  # nosec
    assert result.returncode == 0, (
        f"TensorRT engine build failed for {onnx_save_path} "
        f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
    )
Comment on lines +37 to +57
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether a skip guard for trtexec exists in this test file.
rg -n 'def _verify_trt_engine_build|subprocess\.run|shutil\.which\("trtexec"\)|pytest\.skip' tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 210


🏁 Script executed:

#!/bin/bash
# Check the full test file structure to understand the test matrix and any skip conditions
head -80 tests/examples/torch_onnx/test_torch_quant_to_onnx.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 3684


🏁 Script executed:

#!/bin/bash
# Search for pytest decorators, fixtures, and skip conditions across the entire file
rg -n '@pytest\.|@skip|pytest\.mark|skipif|@fixture|parametrize' tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 180


🏁 Script executed:

#!/bin/bash
# Check if trtexec availability is checked anywhere at module level or in setup
rg -n 'trtexec|shutil\.which|which.*trtexec|importorskip' tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 87


Guard trtexec availability to prevent hard failure when TensorRT CLI is unavailable.

The function _verify_trt_engine_build() at line 53 invokes subprocess.run() to execute trtexec without checking if the tool exists in PATH. This will raise FileNotFoundError in test environments lacking TensorRT CLI instead of producing a controlled skip. Additionally, the parametrized test matrix (3 models × 5 quantize modes = 15 test cases), with each case running ONNX export plus a 10-minute TensorRT engine build, likely exceeds the "few minutes" guideline for integration tests in tests/examples/.

Add an availability check at the start of _verify_trt_engine_build():

Suggested fix
+import shutil
 import subprocess
@@
 def _verify_trt_engine_build(onnx_save_path, quantize_mode):
     """Verify the exported ONNX model can be compiled into a TensorRT engine."""
+    if shutil.which("trtexec") is None:
+        pytest.skip("Skipping: `trtexec` is not available in PATH.")
+
     example_dir = os.path.join(
         os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/torch_onnx/test_torch_quant_to_onnx.py` around lines 37 - 57,
The test should guard against missing TensorRT CLI by checking for trtexec
before attempting subprocess.run: in _verify_trt_engine_build(), use
shutil.which("trtexec") and if it returns None call pytest.skip("trtexec not
available; skipping TensorRT engine build") to avoid FileNotFoundError;
additionally wrap the subprocess.run call in a try/except FileNotFoundError that
also calls pytest.skip with the same message and consider lowering or making the
timeout configurable (the subprocess.run call with timeout=600) to avoid very
long test runs for the heavy parametrized matrix.



@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
@pytest.mark.parametrize("model_key", list(_MODELS))
def test_torch_onnx(model_key, quantize_mode):
timm_model_name, model_kwargs = _MODELS[model_key]
onnx_save_path = f"{model_key}.{quantize_mode}.onnx"

# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12
@pytest.mark.parametrize(
("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"),
[
("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"),
("int8", "vit_base_patch16_224.int8.onnx", "1", "1"),
("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"),
("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"),
("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"),
("auto", "vit_base_patch16_224.auto.onnx", "1", "1"),
],
)
def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps):
# Step 1: Quantize and export to ONNX
cmd_parts = extend_cmd_parts(
["python", "torch_quant_to_onnx.py"],
timm_model_name=timm_model_name,
model_kwargs=model_kwargs,
quantize_mode=quantize_mode,
onnx_save_path=onnx_save_path,
calibration_data_size=calib_size,
num_score_steps=num_score_steps,
calibration_data_size="1",
num_score_steps="1",
)
cmd_parts.append("--no_pretrained")
run_example_command(cmd_parts, "torch_onnx")

# Step 2: Verify the exported ONNX model builds a TensorRT engine
_verify_trt_engine_build(onnx_save_path, quantize_mode)
Loading