Skip to content

Commit 085d684

Browse files
authored
Merge branch 'main' into jingyux/implicit-gemm-nvfp4-e2e
2 parents 47b036f + 73be810 commit 085d684

24 files changed

Lines changed: 514 additions & 129 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,7 @@ def quantize_main(
10591059
assert isinstance(recipe, ModelOptPTQRecipe), (
10601060
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
10611061
)
1062-
quant_cfg = recipe.quantize
1062+
quant_cfg = recipe.quantize.model_dump()
10631063

10641064
else:
10651065
assert len(args.qformat.split(",")) == 1, (

examples/onnx_ptq/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Most of the examples in this doc use `vit_base_patch16_224.onnx` as the input mo
5656

5757
```bash
5858
python download_example_onnx.py \
59-
--vit \
59+
--timm_model_name=vit_base_patch16_224 \
6060
--onnx_save_path=vit_base_patch16_224.onnx \
6161
--fp16 # <Optional, if the desired output ONNX precision is FP16>
6262
```

examples/onnx_ptq/download_example_onnx.py

Lines changed: 20 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import os
18-
import subprocess
1918

2019
import timm
2120
import torch
@@ -46,14 +45,10 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
4645
parser = argparse.ArgumentParser(description="Download and export example models to ONNX.")
4746

4847
parser.add_argument(
49-
"--vit",
50-
action="store_true",
51-
help="Export timm/vit_base_patch16_224 model to ONNX.",
52-
)
53-
parser.add_argument(
54-
"--llama",
55-
action="store_true",
56-
help="Export meta-llama/Llama-3.1-8B-Instruct to ONNX with KV cache.",
48+
"--timm_model_name",
49+
type=str,
50+
required=True,
51+
help="Export any timm model to ONNX (e.g., vit_base_patch16_224, swin_tiny_patch4_window7_224).",
5752
)
5853
parser.add_argument(
5954
"--onnx_save_path", type=str, required=False, help="Path to save the final ONNX model."
@@ -62,7 +57,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
6257
"--batch_size",
6358
type=int,
6459
default=1,
65-
help="Batch size for the exported ViT model.",
60+
help="Batch size for the exported model.",
6661
)
6762
parser.add_argument(
6863
"--fp16",
@@ -71,54 +66,18 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
7166
)
7267
args = parser.parse_args()
7368

74-
if args.vit:
75-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76-
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=1000).to(
77-
device
78-
)
79-
data_config = timm.data.resolve_model_data_config(model)
80-
input_shape = (args.batch_size,) + data_config["input_size"]
81-
82-
vit_save_path = args.onnx_save_path or "vit_base_patch16_224.onnx"
83-
weights_dtype = "fp16" if args.fp16 else "fp32"
84-
export_to_onnx(
85-
model,
86-
input_shape,
87-
vit_save_path,
88-
device,
89-
weights_dtype=weights_dtype,
90-
)
91-
print(f"ViT model exported to {vit_save_path}")
92-
93-
if args.llama:
94-
model_name = "meta-llama/Llama-3.1-8B-Instruct"
95-
if not args.onnx_save_path:
96-
args.onnx_save_path = "Llama-3.1-8B-Instruct/model.onnx"
97-
98-
output_dir = os.path.dirname(args.onnx_save_path)
99-
if not output_dir: # Handle cases where only filename is given (save in current dir)
100-
output_dir = "."
101-
os.makedirs(output_dir, exist_ok=True)
102-
103-
command = [
104-
"python",
105-
"-m",
106-
"optimum.commands.optimum_cli",
107-
"export",
108-
"onnx",
109-
"--model",
110-
model_name,
111-
"--task",
112-
"causal-lm-with-past",
113-
"--device",
114-
"cuda",
115-
"--fp16" if args.fp16 else "",
116-
output_dir,
117-
]
118-
119-
try:
120-
print(f"Running optimum-cli export to {output_dir}...")
121-
subprocess.run(command, check=True, capture_output=True, text=True, encoding="utf-8")
122-
print(f"Llama model exported to {output_dir}")
123-
except subprocess.CalledProcessError as e:
124-
raise RuntimeError(f"Failed to export model: {e.stderr}")
69+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70+
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
71+
data_config = timm.data.resolve_model_data_config(model)
72+
input_shape = (args.batch_size,) + data_config["input_size"]
73+
74+
save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx"
75+
weights_dtype = "fp16" if args.fp16 else "fp32"
76+
export_to_onnx(
77+
model,
78+
input_shape,
79+
save_path,
80+
device,
81+
weights_dtype=weights_dtype,
82+
)
83+
print(f"{args.timm_model_name} model exported to {save_path}")

examples/speculative_decoding/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
import modelopt.torch.opt as mto
5050
import modelopt.torch.speculative as mtsp
51+
from modelopt.torch.speculative.config import EagleConfig
5152
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5253
from modelopt.torch.utils import print_rank_0
5354

@@ -266,8 +267,11 @@ def train():
266267
}
267268
mtsp.convert(model, [("medusa", config)])
268269
elif training_args.mode == "eagle3":
269-
# eagle_cfg maps directly to EagleConfig fields; eagle_offline is derived here.
270-
eagle_cfg["eagle_offline"] = use_offline_training
270+
# Validate and rewrite eagle config fields
271+
eagle_cfg = EagleConfig.model_validate(
272+
eagle_cfg,
273+
context={"training_args": training_args, "data_args": data_args},
274+
).model_dump()
271275
mtsp.convert(model, [("eagle", eagle_cfg)])
272276

273277
# Load draft vocab cache if the draft model uses a compressed vocabulary

examples/torch_onnx/README.md

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
5353

5454
- Loads a pretrained timm torch model (default: ViT-Base).
5555
- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
56+
- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
5657
- Exports the quantized model to ONNX.
5758
- Postprocesses the ONNX model to be compatible with TensorRT.
5859
- Saves the final ONNX model.
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf
6364

6465
```bash
6566
python torch_quant_to_onnx.py \
66-
--timm_model_name=vit_base_patch16_224 \
67+
--timm_model_name=<timm model name> \
6768
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
6869
--onnx_save_path=<path to save the exported ONNX model>
6970
```
7071

72+
### Conv2d Quantization Override
73+
74+
TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:
75+
76+
| Quantize Mode | Conv2d Override | Reason |
77+
| :---: | :---: | :--- |
78+
| FP8, INT8 | None (already compatible) | Native TRT support |
79+
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
80+
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |
81+
7182
### Evaluation
7283

7384
If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
@@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
7990
--onnx_path=<path to the exported ONNX model> \
8091
--imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
8192
--engine_precision=stronglyTyped \
82-
--model_name=vit_base_patch16_224
93+
--model_name=<timm model name>
8394
```
8495

8596
## LLM Quantization and Export with TensorRT-Edge-LLM
@@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
289300
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
290301
```
291302

292-
### Results (ViT-Base)
303+
## ONNX Export Supported Vision Models
293304

294-
| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
295-
| :--- | :---: | :---: |
296-
| Torch autocast (FP16) | 85.11% | 97.53% |
297-
| NVFP4 Quantized | 84.558% | 97.36% |
298-
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
305+
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
306+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
307+
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | | |||| |
308+
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) || | ||| |
309+
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) ||||| | |
299310

300311
## Resources
301312

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import copy
18+
import json
1719
import re
1820
import sys
21+
import warnings
1922
from pathlib import Path
2023

2124
# Add onnx_ptq to path for shared modules
@@ -44,20 +47,69 @@
4447

4548
mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers
4649

47-
QUANT_CONFIG_DICT = {
50+
QUANT_CONFIG_DICT: dict[str, dict] = {
4851
"fp8": mtq.FP8_DEFAULT_CFG,
4952
"int8": mtq.INT8_DEFAULT_CFG,
5053
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
5154
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
5255
"int4_awq": mtq.INT4_AWQ_CFG,
5356
}
5457

58+
_FP8_CONV_OVERRIDE: list = [
59+
{
60+
"parent_class": "nn.Conv2d",
61+
"quantizer_name": "*weight_quantizer",
62+
"cfg": {"num_bits": (4, 3), "axis": None},
63+
},
64+
{
65+
"parent_class": "nn.Conv2d",
66+
"quantizer_name": "*input_quantizer",
67+
"cfg": {"num_bits": (4, 3), "axis": None},
68+
},
69+
]
70+
71+
_INT8_CONV_OVERRIDE: list = [
72+
{
73+
"parent_class": "nn.Conv2d",
74+
"quantizer_name": "*weight_quantizer",
75+
"cfg": {"num_bits": 8, "axis": 0},
76+
},
77+
{
78+
"parent_class": "nn.Conv2d",
79+
"quantizer_name": "*input_quantizer",
80+
"cfg": {"num_bits": 8, "axis": None},
81+
},
82+
]
83+
84+
85+
def get_quant_config(quantize_mode):
86+
"""Get quantization config, overriding Conv2d for TRT compatibility.
87+
88+
TensorRT only supports FP8 and INT8 for Conv layers.
89+
- For MXFP8, NVFP4: override Conv2d to FP8
90+
- For INT4_AWQ: override Conv2d to INT8
91+
"""
92+
config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
93+
if quantize_mode in ("mxfp8", "nvfp4"):
94+
warnings.warn(
95+
f"TensorRT only supports FP8/INT8 for Conv layers. "
96+
f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
97+
)
98+
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
99+
elif quantize_mode == "int4_awq":
100+
warnings.warn(
101+
"TensorRT only supports FP8/INT8 for Conv layers. "
102+
"Overriding Conv2d quantization to INT8 for 'int4_awq' mode."
103+
)
104+
config["quant_cfg"].extend(_INT8_CONV_OVERRIDE)
105+
return config
106+
55107

56108
def filter_func(name):
57109
"""Filter function to exclude certain layers from quantization."""
58110
pattern = re.compile(
59111
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
60-
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed).*"
112+
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
61113
)
62114
return pattern.match(name) is not None
63115

@@ -121,6 +173,18 @@ def loss_func(output, batch):
121173
return F.cross_entropy(output, batch["label"])
122174

123175

176+
def _disable_inplace_relu(model):
177+
"""Replace inplace ReLU with non-inplace ReLU throughout the model.
178+
179+
This is needed for auto_quantize which uses backward hooks for gradient-based
180+
sensitivity scoring. Inplace ReLU on views created by custom Functions causes
181+
PyTorch autograd errors.
182+
"""
183+
for module in model.modules():
184+
if isinstance(module, torch.nn.ReLU) and module.inplace:
185+
module.inplace = False
186+
187+
124188
def auto_quantize_model(
125189
model,
126190
data_loader,
@@ -142,6 +206,7 @@ def auto_quantize_model(
142206
Returns:
143207
Tuple of (quantized_model, search_state_dict)
144208
"""
209+
_disable_inplace_relu(model)
145210
constraints = {"effective_bits": effective_bits}
146211

147212
# Convert string format names to actual config objects
@@ -255,12 +320,29 @@ def main():
255320
default=128,
256321
help="Number of scoring steps for auto quantization. Default is 128.",
257322
)
323+
parser.add_argument(
324+
"--no_pretrained",
325+
action="store_true",
326+
help="Don't load pretrained weights (useful for testing with random weights).",
327+
)
328+
parser.add_argument(
329+
"--model_kwargs",
330+
type=str,
331+
default=None,
332+
help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').",
333+
)
258334

259335
args = parser.parse_args()
260336

261337
# Create model and move to appropriate device
262338
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
263-
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
339+
model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
340+
model = timm.create_model(
341+
args.timm_model_name,
342+
pretrained=not args.no_pretrained,
343+
num_classes=1000,
344+
**model_kwargs,
345+
).to(device)
264346

265347
# Get input shape from model config
266348
input_size = get_model_input_shape(model)
@@ -297,7 +379,7 @@ def main():
297379
)
298380
else:
299381
# Standard quantization - only load calibration data if needed
300-
config = QUANT_CONFIG_DICT[args.quantize_mode]
382+
config = get_quant_config(args.quantize_mode)
301383
if args.quantize_mode == "mxfp8":
302384
data_loader = None
303385
else:

examples/vllm_serve/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`,
2828
| QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None |
2929
| MODELOPT_STATE_PATH | Optional path to exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None |
3030
| CALIB_BATCH_SIZE | Calibration batch size | 1 |
31+
| RECIPE_PATH | Optional path to a ModelOpt PTQ recipe YAML | None |
3132

3233
Set these variables in your shell or Docker environment as needed to customize calibration.
3334

@@ -65,7 +66,7 @@ Step 1: export the model with bf16 weights and quantizer state. To export the mo
6566
```bash
6667
python ../llm_ptq/hf_ptq.py \
6768
--pyt_ckpt_path <MODEL_PATH> \
68-
--qformat nvfp4 \
69+
--recipe <PATH_TO_RECIPE> \
6970
--calib_size 512 \
7071
--export_path <EXPORT_DIR> \
7172
--vllm_fakequant_export \

examples/vllm_serve/fakequant_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"quant_file_path": os.environ.get("QUANT_FILE_PATH", None),
4444
"modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None),
4545
"calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 1)),
46+
"recipe_path": os.environ.get("RECIPE_PATH", None),
4647
}
4748

4849

@@ -138,6 +139,7 @@ def compile_or_warm_up_model(self) -> None:
138139
quant_config["quant_cfg"]
139140
or quant_config["kv_quant_cfg"]
140141
or quant_config["modelopt_state_path"]
142+
or quant_config["recipe_path"]
141143
):
142144
_fakequant_run_prolog_worker(self)
143145
super().compile_or_warm_up_model()

0 commit comments

Comments
 (0)