
Commit 928d417

Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite the attention-scaling Mul and the K Transpose to the Q side so DQ feeds MatMul directly, pre-transpose weight constants, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. The scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold the Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep the FP8 Q/DQ scale in the native input dtype so no Cast is injected between the graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT FP8 config (extends FP8_DEFAULT_CFG with a LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus an accuracy and TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8), deltas of -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8), a 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
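The Q-side rewrite is numerically safe because scalar multiplication commutes through MatMul: scaling Q before Q @ K^T gives the same logits as scaling the product afterward, which is what lets the exporter move the Mul so DQ can feed the MatMul directly. A minimal numpy sketch of that equivalence (shapes and names are illustrative, not taken from the exporter):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8)).astype(np.float32)  # toy Q: 4 tokens, head_dim=8
k = rng.standard_normal((4, 8)).astype(np.float32)  # toy K
scale = 8 ** -0.5  # 1/sqrt(head_dim), the usual attention scale

# Original graph shape: scale applied to the logits after Q @ K^T.
logits_post = (q @ k.T) * scale

# Rewritten graph shape: scale folded into Q before the MatMul.
logits_pre = (q * scale) @ k.T

assert np.allclose(logits_post, logits_pre, atol=1e-6)
```

The same argument covers moving the K Transpose to the Q side: both rewrites only relocate operations that commute with the MatMul, so the fused graph computes the same attention logits.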
1 parent 010b220 commit 928d417

10 files changed

Lines changed: 488 additions & 42 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Changelog
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
 - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
+- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q/DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.

 **Backward Breaking Changes**

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 103 additions & 1 deletion
@@ -88,6 +88,36 @@
     },
 ]

+# FP8 MHA-aware config entries: quantize LayerNorm output and Softmax output so TRT can
+# fuse Q/DQ into the attention MatMul kernels. LayerNorm output QDQ is shared across all
+# downstream Q/K/V/FC consumers; Softmax output QDQ is required for MHA-v2 fusion on the
+# attn@V MatMul. Relies on ``torch.nn.LayerNorm`` and ``torch.nn.Softmax`` being registered
+# in ``QuantModuleRegistry`` (see ``modelopt/torch/quantization/nn/modules``).
+_FP8_MHA_OVERRIDE: list = [
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*output_quantizer",
+        "cfg": {"num_bits": (4, 3), "axis": None},
+    },
+    {
+        "parent_class": "nn.LayerNorm",
+        "quantizer_name": "*input_quantizer",
+        "enable": False,
+    },
+    {
+        "parent_class": "nn.Softmax",
+        "quantizer_name": "*output_quantizer",
+        "cfg": {"num_bits": (4, 3), "axis": None},
+    },
+    {
+        # Pre-softmax Q/DQ can't fuse into the Q@K^T MatMul (no TRT kernel for
+        # MatMul→Softmax fusion through Q/DQ) and just adds overhead.
+        "parent_class": "nn.Softmax",
+        "quantizer_name": "*input_quantizer",
+        "enable": False,
+    },
+]
+
 # Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
 # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
 _NEEDS_FP8_CONV_OVERRIDE: set[str] = {
@@ -102,11 +132,14 @@ def get_quant_config(quantize_mode):
     """Get quantization config, overriding Conv2d for TRT compatibility.

     TensorRT only supports FP8 and INT8 for Conv layers.
+    - For FP8: add MHA-aware LayerNorm/Softmax output quantizers for transformer fusion.
     - For MXFP8, NVFP4: override Conv2d to FP8
     - For INT4_AWQ: override Conv2d to INT8
     """
     config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])
-    if quantize_mode in ("mxfp8", "nvfp4"):
+    if quantize_mode == "fp8":
+        config["quant_cfg"].extend(_FP8_MHA_OVERRIDE)
+    elif quantize_mode in ("mxfp8", "nvfp4"):
         warnings.warn(
             f"TensorRT only supports FP8/INT8 for Conv layers. "
             f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
@@ -121,6 +154,67 @@ def get_quant_config(quantize_mode):
     return config


+def _inject_softmax_modules(model):
+    """Replace timm vision-transformer ``F.softmax`` with ``nn.Softmax`` submodules.
+
+    Timm's ``Attention.forward`` calls ``attn.softmax(dim=-1)`` (``F.softmax``), which
+    exposes no module for quantization. We disable ``fused_attn`` (so the non-fused
+    path runs) and attach a ``self.softmax = nn.Softmax(dim=-1)`` child, then rebind
+    ``forward`` to call that submodule. Combined with ``nn.Softmax`` registered in
+    ``QuantModuleRegistry``, this means a standard ``mtq.quantize`` pass will add
+    and calibrate the softmax output quantizer.
+
+    Returns the count of patched attention modules.
+    """
+    try:
+        from timm.models.vision_transformer import Attention as _VitAttention
+    except ImportError:
+        return 0
+
+    patched = 0
+    for _, module in model.named_modules():
+        if not isinstance(module, _VitAttention):
+            continue
+        module.fused_attn = False
+        if not isinstance(getattr(module, "softmax", None), torch.nn.Softmax):
+            module.softmax = torch.nn.Softmax(dim=-1)
+        module.forward = _vit_attention_forward.__get__(module, type(module))
+        patched += 1
+    return patched
+
+
+def _vit_attention_forward(self, x, attn_mask=None, is_causal=False):
+    """Replacement for timm ``Attention.forward`` that routes softmax through ``self.softmax``.
+
+    Mirrors the non-fused branch of upstream timm's implementation for the case without
+    masking/causal (the default for image classifiers). ``self.softmax`` is a real
+    ``nn.Softmax`` module, so its output_quantizer is honored during quantization.
+    """
+    B, N, C = x.shape
+    qkv = (
+        self.qkv(x)
+        .reshape(B, N, 3, self.num_heads, self.head_dim)
+        .permute(2, 0, 3, 1, 4)
+    )
+    q, k, v = qkv.unbind(0)
+    q, k = self.q_norm(q), self.k_norm(k)
+
+    q = q * self.scale
+    attn = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+        attn = attn + attn_mask
+    attn = self.softmax(attn)
+    attn = self.attn_drop(attn)
+    x = attn @ v
+
+    x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
+    if hasattr(self, "norm"):
+        x = self.norm(x)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x
+
+
 def filter_func(name):
     """Filter function to exclude certain layers from quantization.
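The `_vit_attention_forward.__get__(module, type(module))` expression in the hunk above uses the descriptor protocol to bind a plain function as a method of one specific instance, so only the patched attention modules change behavior while the class and any unpatched instances keep the original `forward`. A self-contained sketch of the technique (toy class, not timm's Attention):

```python
class Attention:
    """Toy stand-in for timm's Attention (illustrative only)."""

    def forward(self, x):
        return f"original:{x}"


def patched_forward(self, x):
    # A plain function; __get__ turns it into a bound method of one instance,
    # so it can read per-instance state such as self.tag.
    return f"{self.tag}:{x}"


m = Attention()
m.tag = "patched"
# Same rebinding pattern as module.forward = _vit_attention_forward.__get__(module, type(module)):
# the attribute is set on the instance, shadowing the class-level forward for m only.
m.forward = patched_forward.__get__(m, type(m))
```

Setting `m.forward` on the instance shadows the class attribute during lookup, which is why a later `mtq.quantize` pass sees the patched path only on the modules `_inject_softmax_modules` touched.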
@@ -458,6 +552,14 @@ def main():
     # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
     # quantizers require calibration data.
     config = get_quant_config(args.quantize_mode)
+
+    if args.quantize_mode == "fp8":
+        # Swap timm Attention's internal F.softmax for an nn.Softmax submodule so
+        # the output_quantizer declared in _FP8_MHA_OVERRIDE picks it up.
+        n_patched = _inject_softmax_modules(model)
+        if n_patched:
+            print(f"Patched {n_patched} timm Attention modules for softmax output quantization")
+
     data_loader = load_calibration_data(
         model,
         args.calibration_data_size,
