
Commit 48d8486
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite the attention-scaling Mul and K Transpose to the Q side so DQ feeds MatMul directly, pre-transpose weight constants, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. The scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold the Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep the FP8 Q/DQ scale in the native input dtype so no Cast is injected between the graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT FP8 config (extends FP8_DEFAULT_CFG with a LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus an accuracy and TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):

- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8), a -0.20% / -0.06% change
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8), a 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
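The "scale dtype matches the graph's float dtype" change in the fp8_exporter and export_onnx bullets can be sketched as follows. This is a minimal illustrative helper, not the ModelOpt implementation: it assumes per-tensor max calibration and uses 448, the largest finite FP8 E4M3 value, as the divisor mentioned elsewhere in this commit (the fixed 1/448 softmax-output scale).

```python
import numpy as np

# Largest finite value representable in FP8 E4M3.
E4M3_MAX = 448.0


def fp8_qdq_scale(tensor: np.ndarray) -> np.ndarray:
    """Per-tensor FP8 Q/DQ scale kept in the tensor's native float dtype.

    Emitting the scale as e.g. float16 for an FP16 graph means the exporter
    never needs a Cast between the graph tensor and the Q/DQ pair, which
    would otherwise block TensorRT's DQ->MatMul fusion.
    """
    amax = np.abs(tensor).max()
    # Map amax onto the FP8 E4M3 dynamic range, then cast back to the
    # tensor's own dtype so strongly-typed builds stay consistent.
    return np.asarray(amax / E4M3_MAX, dtype=tensor.dtype)


x = np.random.randn(4, 4).astype(np.float16)
scale = fp8_qdq_scale(x)
assert scale.dtype == np.float16  # FP16 scale, so no Cast is injected
```

A softmax output is bounded in [0, 1], which is why the exporter can insert its Q/DQ with a fixed, data-independent scale and skip calibration entirely.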
1 parent 010b220 commit 48d8486

13 files changed: 932 additions & 42 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Changelog
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
 - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
+- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.
 
 **Backward Breaking Changes**
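The nested-attention-wrapper skipping mentioned in the changelog entry (and in the huggingface plugin bullet of the commit message) can be illustrated with a small name-matching check. This is a hedged sketch: the class names ViTAttention/ViTSelfAttention come from the commit, but the helper itself is hypothetical, not the ModelOpt plugin API.

```python
import fnmatch


def is_nested_attention_wrapper(cls_name: str, child_cls_names: list[str]) -> bool:
    """True when an "*Attention" module contains another "*Attention" child.

    Patching both ViTAttention and its ViTSelfAttention child would wrap
    eager_attention_forward twice, so the outer wrapper should be skipped
    and only the innermost attention module patched.
    """
    return fnmatch.fnmatch(cls_name, "*Attention") and any(
        fnmatch.fnmatch(c, "*Attention") for c in child_cls_names
    )


# Outer ViT wrapper holds a ViTSelfAttention child -> skip it.
assert is_nested_attention_wrapper("ViTAttention", ["ViTSelfAttention", "ViTSelfOutput"])
# Innermost attention module has no "*Attention" children -> patch it.
assert not is_nested_attention_wrapper("ViTSelfAttention", ["Linear", "Dropout"])
```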

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 23 additions & 1 deletion
@@ -88,6 +88,22 @@
     },
 ]
 
+# FP8 MHA-aware config entries: quantize LayerNorm output so TRT can fuse the shared
+# Q/DQ across all downstream Q/K/V/FC consumers. Softmax-output Q/DQ is handled by the
+# FP8 ONNX exporter's post-processing pass (fixed 1/448 scale, data-independent).
+_FP8_MHA_OVERRIDE: list = [
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*output_quantizer",
+        "cfg": {"num_bits": (4, 3), "axis": None},
+    },
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*input_quantizer",
+        "enable": False,
+    },
+]
+
 # Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
 # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
 _NEEDS_FP8_CONV_OVERRIDE: set[str] = {
@@ -102,11 +118,16 @@ def get_quant_config(quantize_mode):
     """Get quantization config, overriding Conv2d for TRT compatibility.
 
     TensorRT only supports FP8 and INT8 for Conv layers.
+    - For FP8: add MHA-aware LayerNorm output quantizer so TRT fuses shared Q/DQ into
+      downstream attention matmuls. Softmax-output Q/DQ is inserted by the FP8 ONNX
+      exporter's post-processing (fixed 1/448 scale, no calibration needed).
     - For MXFP8, NVFP4: override Conv2d to FP8
     - For INT4_AWQ: override Conv2d to INT8
     """
     config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
-    if quantize_mode in ("mxfp8", "nvfp4"):
+    if quantize_mode == "fp8":
+        config["quant_cfg"].extend(_FP8_MHA_OVERRIDE)
+    elif quantize_mode in ("mxfp8", "nvfp4"):
         warnings.warn(
             f"TensorRT only supports FP8/INT8 for Conv layers. "
             f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
@@ -458,6 +479,7 @@ def main():
     # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
     # quantizers require calibration data.
     config = get_quant_config(args.quantize_mode)
+
     data_loader = load_calibration_data(
         model,
         args.calibration_data_size,
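The get_quant_config change in the diff above can be exercised as a self-contained sketch. The QUANT_CONFIG_DICT here is a one-entry stand-in for the real config table (which in the example file is built from FP8_DEFAULT_CFG), kept only so the extend-and-deepcopy semantics are runnable in isolation.

```python
import copy

# Stand-in for the real module-level config table; only the shape matters here.
QUANT_CONFIG_DICT = {
    "fp8": {
        "quant_cfg": [
            {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}},
        ]
    },
}

# MHA-aware entries from the diff: quantize LayerNorm output, disable its input quantizer.
_FP8_MHA_OVERRIDE: list = [
    {"parent_class": "nn.LayerNorm", "quantizer_name": "*output_quantizer",
     "cfg": {"num_bits": (4, 3), "axis": None}},
    {"parent_class": "nn.LayerNorm", "quantizer_name": "*input_quantizer",
     "enable": False},
]


def get_quant_config(quantize_mode: str) -> dict:
    # deepcopy so the shared template is never mutated by the extend below
    config = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
    if quantize_mode == "fp8":
        config["quant_cfg"].extend(_FP8_MHA_OVERRIDE)
    return config


cfg = get_quant_config("fp8")
assert len(cfg["quant_cfg"]) == 3
assert len(QUANT_CONFIG_DICT["fp8"]["quant_cfg"]) == 1  # template untouched
```

The deepcopy before extend is the important design point: without it, every call for "fp8" would append the override entries to the shared module-level template again.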
