Commit f6f62a3
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite the attention-scaling Mul and the K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. The scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold the Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep the FP8 Q/DQ scale in the native input dtype so no Cast is injected between the graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT FP8 config (extends FP8_DEFAULT_CFG with a LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus an accuracy and TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):

- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8), a 0.20% / 0.06% drop
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8), a 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
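The Softmax-output Q/DQ mentioned above can use a fixed, data-independent scale: softmax probabilities lie in [0, 1], and FP8 E4M3's largest finite value is 448, so a scale of 1/448 maps the full output range onto the FP8 range with no calibration. A minimal sketch of the resulting quant/dequant error bound (plain Python; the helper below is a crude stand-in for illustration, not the exporter's actual FP8 kernel, which emits ONNX QuantizeLinear/DequantizeLinear nodes):

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def quantize_dequantize_fp8(x: float, scale: float) -> float:
    """Crude FP8 E4M3 quant/dequant model: divide by scale, saturate to the
    FP8 range, round the mantissa to 1 implicit + 3 stored bits, rescale.
    Ignores subnormals, which only matter for negligible softmax tails."""
    q = x / scale
    q = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, q))  # saturate
    if q == 0.0:
        return 0.0
    m, e = math.frexp(q)        # q = m * 2**e with 0.5 <= |m| < 1
    m = round(m * 16) / 16      # keep 4 significand bits
    return math.ldexp(m, e) * scale

scale = 1.0 / FP8_E4M3_MAX  # softmax outputs in [0, 1] -> quantized range [0, 448]
for p in (0.0, 1e-3, 0.25, 0.5, 1.0):
    err = abs(quantize_dequantize_fp8(p, scale) - p)
    # 3 mantissa bits give at most ~6.25% relative error for normal values
    assert err <= p / 16 + 1e-12
```

Because the scale is a constant, the exporter can insert these Q/DQ pairs as a pure graph rewrite with no calibration pass over data.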
1 parent 010b220 commit f6f62a3

13 files changed

Lines changed: 733 additions & 42 deletions

File tree

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,13 @@
 Changelog
 =========

+0.45 (2026-06-xx)
+^^^^^^^^^^^^^^^^^
+
+**New Features**
+
+- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.
+
 0.44 (2026-05-xx)
 ^^^^^^^^^^^^^^^^^
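The ``nn.LayerNorm`` registration matters because only module classes present in the registry get wrapped with quantizers; config entries targeting an unregistered class are silently ignored. A standalone sketch of that registry pattern (the names QuantModuleRegistry and QuantLayerNorm mirror modelopt's, but this is an illustration of the mechanism, not the library's implementation):

```python
class QuantModuleRegistry:
    """Maps module-class names to quantized wrapper classes."""
    _registry: dict = {}

    @classmethod
    def register(cls, name: str):
        def deco(quant_cls: type) -> type:
            cls._registry[name] = quant_cls
            return quant_cls
        return deco

    @classmethod
    def convert(cls, name: str, module: object) -> object:
        quant_cls = cls._registry.get(name)
        # Unregistered classes pass through untouched, so their quantizer
        # config entries never take effect.
        return quant_cls(module) if quant_cls else module

@QuantModuleRegistry.register("nn.LayerNorm")
class QuantLayerNorm:
    def __init__(self, inner: object):
        self.inner = inner
        self.output_quantizer_enabled = True  # now honors *output_quantizer cfg

wrapped = QuantModuleRegistry.convert("nn.LayerNorm", object())
assert isinstance(wrapped, QuantLayerNorm)
```

With the registration in place, the ``*output_quantizer`` entries for LayerNorm in the FP8 config are applied instead of being dropped.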

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 23 additions & 1 deletion
@@ -88,6 +88,22 @@
     },
 ]

+# FP8 MHA-aware config entries: quantize LayerNorm output so TRT can fuse the shared
+# Q/DQ across all downstream Q/K/V/FC consumers. Softmax-output Q/DQ is handled by the
+# FP8 ONNX exporter's post-processing pass (fixed 1/448 scale, data-independent).
+_FP8_MHA_OVERRIDE: list = [
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*output_quantizer",
+        "cfg": {"num_bits": (4, 3), "axis": None},
+    },
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*input_quantizer",
+        "enable": False,
+    },
+]
+
 # Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
 # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
 _NEEDS_FP8_CONV_OVERRIDE: set[str] = {
@@ -102,11 +118,16 @@ def get_quant_config(quantize_mode):
     """Get quantization config, overriding Conv2d for TRT compatibility.

     TensorRT only supports FP8 and INT8 for Conv layers.
+    - For FP8: add MHA-aware LayerNorm output quantizer so TRT fuses shared Q/DQ into
+      downstream attention matmuls. Softmax-output Q/DQ is inserted by the FP8 ONNX
+      exporter's post-processing (fixed 1/448 scale, no calibration needed).
     - For MXFP8, NVFP4: override Conv2d to FP8
     - For INT4_AWQ: override Conv2d to INT8
     """
     config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
-    if quantize_mode in ("mxfp8", "nvfp4"):
+    if quantize_mode == "fp8":
+        config["quant_cfg"].extend(_FP8_MHA_OVERRIDE)
+    elif quantize_mode in ("mxfp8", "nvfp4"):
         warnings.warn(
             f"TensorRT only supports FP8/INT8 for Conv layers. "
             f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
@@ -458,6 +479,7 @@ def main():
     # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
     # quantizers require calibration data.
     config = get_quant_config(args.quantize_mode)
+
     data_loader = load_calibration_data(
         model,
         args.calibration_data_size,
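The override pattern in get_quant_config is append-then-refine: the base config's wildcard entries stay in place and the MHA-specific entries are added after them, while deepcopy keeps the shared base dict untouched. A standalone sketch of just that logic (the base QUANT_CONFIG_DICT here is a minimal hypothetical stand-in, not modelopt's real table):

```python
import copy

# MHA-aware entries, mirroring the diff above.
_FP8_MHA_OVERRIDE = [
    {"parent_class": "nn.LayerNorm", "quantizer_name": "*output_quantizer",
     "cfg": {"num_bits": (4, 3), "axis": None}},
    {"parent_class": "nn.LayerNorm", "quantizer_name": "*input_quantizer",
     "enable": False},
]

# Hypothetical minimal base config; the real script reads modelopt's
# QUANT_CONFIG_DICT instead.
QUANT_CONFIG_DICT = {
    "fp8": {"quant_cfg": [{"quantizer_name": "*weight_quantizer"}]},
}

def get_quant_config(quantize_mode: str) -> dict:
    # deepcopy so per-call overrides never mutate the shared base table
    config = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
    if quantize_mode == "fp8":
        # Append, don't replace: later entries refine earlier wildcard matches.
        config["quant_cfg"].extend(_FP8_MHA_OVERRIDE)
    return config

cfg = get_quant_config("fp8")
assert len(cfg["quant_cfg"]) == 3
assert len(QUANT_CONFIG_DICT["fp8"]["quant_cfg"]) == 1  # base left intact
```

Appending rather than editing the base entries keeps the FP8 path identical to FP8_DEFAULT_CFG everywhere the MHA overrides do not match.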
