
Commit ef8c769

Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite the attention-scaling Mul and K Transpose to the Q side so DQ feeds MatMul directly, pre-transpose weight constants, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep the FP8 Q/DQ scale in the native input dtype so no Cast is injected between the graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT FP8 config (extends FP8_DEFAULT_CFG with a LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy and TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8), a drop of 0.20% / 0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8), a 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
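The commit message describes the ViT FP8 config as FP8_DEFAULT_CFG plus a LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries. A minimal sketch of what such a config dict could look like — the wildcard key names and the FP8_DEFAULT_CFG stub below are illustrative assumptions, not ModelOpt's actual values (see examples/torch_onnx/vit_mha_quantization.py for the real config):

```python
# Hypothetical sketch of the ViT FP8 MHA config described in the commit message.
# FP8_DEFAULT_CFG normally comes from modelopt.torch.quantization; a stub stands
# in here so the sketch is self-contained. num_bits=(4, 3) denotes FP8 E4M3.
import copy

FP8_DEFAULT_CFG = {  # stub standing in for modelopt's FP8_DEFAULT_CFG
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
    },
    "algorithm": "max",
}

VIT_FP8_MHA_CFG = copy.deepcopy(FP8_DEFAULT_CFG)
VIT_FP8_MHA_CFG["quant_cfg"].update(
    {
        # Quantize the LayerNorm output once, shared by all Q/K/V/FC consumers.
        "*layernorm*output_quantizer": {"num_bits": (4, 3), "axis": None},
        # The shared LayerNorm output QDQ replaces per-input QDQs downstream
        # (illustrative pattern for "LayerNorm-followed layers").
        "*attention*input_quantizer": {"enable": False},
        # FP8 Q/DQ around the attention BMMs (*_bmm_quantizer entries).
        "*q_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "*softmax_quantizer": {"num_bits": (4, 3), "axis": None},
    }
)
```

The derived config would then be passed to the usual `mtq.quantize(model, VIT_FP8_MHA_CFG, calib_fn)` flow before ONNX export.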
1 parent 010b220 commit ef8c769

8 files changed

Lines changed: 424 additions & 41 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Changelog
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
 - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
+- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.

 **Backward Breaking Changes**

modelopt/onnx/export/fp8_exporter.py

Lines changed: 281 additions & 9 deletions
Large diffs are not rendered by default.

modelopt/onnx/utils.py

Lines changed: 81 additions & 1 deletion
@@ -1415,6 +1415,74 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None:
             consumer.input[i] = input_tensor


+_DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
+_Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"}
+
+
+def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
+    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16.
+
+    Warns if any non-zero scale saturates to 0/inf in FP16 (out of FP16 representable range).
+    """
+    if scale_init.data_type != onnx.TensorProto.FLOAT:
+        return
+    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
+    if not scale_data.size:
+        scale_data = np.array(scale_init.float_data, dtype=np.float32)
+    fp16_data = scale_data.astype(np.float16)
+    if np.any(np.isinf(fp16_data)) or (
+        np.any(fp16_data == 0) and np.any(scale_data != 0)
+    ):
+        logger.warning(
+            f"Q/DQ scale '{scale_init.name}' overflows or underflows when cast to FP16"
+        )
+    scale_init.data_type = onnx.TensorProto.FLOAT16
+    scale_init.raw_data = fp16_data.tobytes()
+    del scale_init.float_data[:]
+
+
+def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.
+
+    The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly. Skipped for
+    opsets below ``BASE_MIN_OPSET`` since FP16 Q scales require opset >= 19.
+    """
+    if get_opset_version(onnx_model) < BASE_MIN_OPSET:
+        logger.debug(
+            f"Skipping fold_q_fp16_to_fp32_casts: opset < {BASE_MIN_OPSET} (FP16 Q scale unsupported)"
+        )
+        return onnx_model
+
+    consumer_map: dict[str, list[onnx.NodeProto]] = {}
+    for node in onnx_model.graph.node:
+        for inp in node.input:
+            consumer_map.setdefault(inp, []).append(node)
+    initializers = {init.name: init for init in onnx_model.graph.initializer}
+
+    to_remove = []
+    for node in onnx_model.graph.node:
+        if node.op_type != "Cast":
+            continue
+        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
+        if cast_to != onnx.TensorProto.FLOAT:
+            continue
+        consumers = consumer_map.get(node.output[0], [])
+        if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
+            continue
+
+        for q_node in consumers:
+            if len(q_node.input) >= 2 and q_node.input[1] in initializers:
+                _scale_fp32_to_fp16(initializers[q_node.input[1]])
+
+        _bypass_cast_node(onnx_model, node)
+        to_remove.append(node)
+
+    logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
+    for node in to_remove:
+        onnx_model.graph.node.remove(node)
+    return onnx_model
+
+
 def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
     """Check if a Constant -> Cast pattern can be folded."""
     assert node.op_type == "Cast"
@@ -1523,7 +1591,12 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     Returns:
         The ONNX model with Cast nodes removed and DQ outputs set to FP16.
     """
-    import numpy as np
+    if get_opset_version(onnx_model) < BASE_MIN_OPSET:
+        logger.debug(
+            f"Skipping fold_dq_fp32_to_fp16_casts: opset < {BASE_MIN_OPSET} "
+            "(FP16 DQ scale unsupported)"
+        )
+        return onnx_model

     dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}

@@ -1623,6 +1696,13 @@ def fold_qdq_scale_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.Model
     Returns:
         The ONNX model with redundant scale-path casts removed.
     """
+    if get_opset_version(onnx_model) < BASE_MIN_OPSET:
+        logger.debug(
+            f"Skipping fold_qdq_scale_fp16_to_fp32_casts: opset < {BASE_MIN_OPSET} "
+            "(FP16 Q/DQ scale unsupported)"
+        )
+        return onnx_model
+
     qdq_ops = {
         "QuantizeLinear",
         "DequantizeLinear",

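The FP16 scale rewrite above warns when a scale saturates out of FP16's representable range. A standalone sketch of that saturation check, using only NumPy (no ONNX) and a hypothetical helper name:

```python
# Standalone sketch of the saturation check in _scale_fp32_to_fp16: an FP32
# Q/DQ scale is narrowed to FP16, and we flag values that overflow to inf
# (above ~65504, the FP16 max) or underflow to 0 (below the smallest FP16
# subnormal, ~6e-8) even though the FP32 value was non-zero.
import numpy as np

def fp16_saturates(scale_fp32: np.ndarray) -> bool:
    """Return True if any non-zero FP32 scale becomes 0 or inf in FP16."""
    fp16 = scale_fp32.astype(np.float16)
    overflow = np.any(np.isinf(fp16))
    underflow = np.any(fp16 == 0) and np.any(scale_fp32 != 0)
    return bool(overflow or underflow)

ok = np.array([0.0123], dtype=np.float32)         # well inside FP16 range
too_big = np.array([1.0e6], dtype=np.float32)     # > FP16 max -> inf
too_small = np.array([1.0e-9], dtype=np.float32)  # < FP16 subnormal min -> 0
```

FP8 Q/DQ scales calibrated from typical activation amax values land comfortably inside this range, which is why the fold is safe in practice and only warrants a warning rather than a hard error.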
modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,7 @@
     change_casts_to_fp16,
     check_model_uses_external_data,
     fold_dq_fp32_to_fp16_casts,
+    fold_q_fp16_to_fp32_casts,
     fold_qdq_scale_fp16_to_fp32_casts,
     get_input_names,
     get_input_shapes,
@@ -663,6 +664,11 @@ def get_onnx_bytes_and_metadata(

     onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

+    # Remove Cast nodes around Q/DQ for optimal TRT fusion
+    if is_fp8_quantized(model):
+        onnx_opt_graph = fold_q_fp16_to_fp32_casts(onnx_opt_graph)
+        onnx_opt_graph = fold_dq_fp32_to_fp16_casts(onnx_opt_graph)
+
     # TensorRT expects all scales to be postive
     onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

modelopt/torch/quantization/export_onnx.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -216,71 +216,54 @@ def _fp8_quantize(
216216
g: "GraphContext",
217217
inputs: torch.Value,
218218
scale_inv: float,
219-
trt_high_precision_dtype: str,
220219
):
221220
"""Helper Function for Quantization."""
221+
# Emit the scale in the native input dtype so no Cast is inserted between the
222+
# graph and Q/DQ (Cast nodes block TRT from fusing DQ into the MatMul kernel).
222223
output_shape = sym_help._get_tensor_sizes(inputs)
223-
224-
# TRT StronglyType only supports FP16 QDQs
225-
# custom ops, so cast the input if needed.
226-
input_type = inputs.type().scalarType()
227-
assert trt_high_precision_dtype in (input_type, "Float"), (
228-
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
229-
)
230-
if trt_high_precision_dtype != input_type:
231-
inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
232-
233224
scale = g.op(
234225
"Constant",
235-
value_t=torch.tensor(scale_inv).to(torch_dtype_map[trt_high_precision_dtype]),
226+
value_t=torch.tensor(scale_inv).to(torch_dtype_map[inputs.type().scalarType()]),
236227
)
237-
q_op = g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType(
228+
return g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType(
238229
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)
239230
)
240-
return q_op
241231

242232

243233
def _fp8_dequantize(
244234
g: "GraphContext",
245235
inputs: torch.Value,
246236
scale_inv: float,
247-
trt_high_precision_dtype: str,
248237
otype: str | None = None,
249238
):
250239
"""Helper Function for Dequantization."""
251240
output_shape = sym_help._get_tensor_sizes(inputs)
252-
assert trt_high_precision_dtype in (otype, "Float"), (
253-
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
254-
)
255241
scale = g.op(
256242
"Constant",
257243
value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]), # type: ignore[index]
258244
)
259-
out = g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType(
260-
inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
245+
return g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType(
246+
inputs.type().with_dtype(torch_dtype_map[otype]).with_sizes(output_shape) # type: ignore[index]
261247
)
262248

263-
# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
264-
# custom ops, so cast the output if needed.
265-
if trt_high_precision_dtype != otype:
266-
out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index]
267-
return out
268-
269249

270250
def export_fp8(
271251
g: "GraphContext",
272252
inputs: torch.Value,
273253
amax: float,
274254
trt_high_precision_dtype: str | None,
275255
):
276-
"""Export quantized model to FP8 ONNX."""
256+
"""Export quantized model to FP8 ONNX.
257+
258+
``trt_high_precision_dtype`` is accepted for API compatibility but unused: Q/DQ now
259+
emit scales in the native input dtype, so no intermediate Cast is required.
260+
"""
261+
del trt_high_precision_dtype
277262
scale = 1.0 if amax is None else 448.0 / float(amax)
278263
otype = inputs.type().scalarType()
279-
if trt_high_precision_dtype is None:
280-
trt_high_precision_dtype = otype
281264

282-
q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
283-
return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
265+
q_tensor = _fp8_quantize(g, inputs, 1.0 / scale)
266+
return _fp8_dequantize(g, q_tensor, 1.0 / scale, otype)
284267

285268

286269
def scaled_dot_product_attention(
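`export_fp8` above derives its scale from the calibrated amax as 448/amax, where 448 is the FP8 E4M3 maximum, and passes the reciprocal (amax/448) to Q/DQ as the ONNX scale. A pure-Python sketch of that scale math and the fake-quant roundtrip it implies — the rounding model below is simplified to uniform nearest-integer steps, not true E4M3 rounding:

```python
# Sketch of the scale math in export_fp8: scale = 448 / amax, and Q/DQ consume
# scale_inv = 1 / scale = amax / 448 as the ONNX scale, so a value at the
# calibrated amax maps exactly to the E4M3 maximum (448) and back.
E4M3_MAX = 448.0

def fp8_scale(amax=None):
    """Mirror of export_fp8's scale computation (scale = 1.0 when amax is None)."""
    return 1.0 if amax is None else E4M3_MAX / float(amax)

def fake_quant(x, amax):
    """Simplified Q/DQ roundtrip: divide by the ONNX scale, round, clip to the
    E4M3 range, multiply back. Real E4M3 rounding is coarser than this."""
    scale_inv = 1.0 / fp8_scale(amax)  # the ONNX Q/DQ scale, i.e. amax / 448
    q = max(-E4M3_MAX, min(E4M3_MAX, round(x / scale_inv)))
    return q * scale_inv

amax = 2.0
assert fp8_scale(amax) == 224.0      # 448 / 2
assert fake_quant(2.0, amax) == 2.0  # the amax value survives the roundtrip
assert fake_quant(10.0, amax) == 2.0  # values above amax clip to amax
```

Keeping the Constant scale in the input's native dtype (as the diff does) means this arithmetic appears in the graph without any interposed Cast, which is what lets TRT fuse the DQ into the consuming MatMul.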

modelopt/torch/quantization/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from .modules.quant_batchnorm import *
 from .modules.quant_conv import *
 from .modules.quant_instancenorm import *
+from .modules.quant_layernorm import *
 from .modules.quant_linear import *
 from .modules.quant_module import *
 from .modules.quant_pooling import *
modelopt/torch/quantization/nn/modules/quant_layernorm.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Registers ``torch.nn.LayerNorm`` with ``QuantInputBase`` so its output quantizer is
+honored during quantization. Required for FP8 attention fusion where a single LayerNorm
+output QDQ is shared across all downstream Q/K/V/FC consumers (instead of repeating it
+on each input), which enables TRT to fuse DQ into the attention MatMul kernels."""
+
+import torch.nn as nn
+
+from .quant_module import QuantInputBase, QuantModuleRegistry
+
+QuantModuleRegistry.register({nn.LayerNorm: "nn.LayerNorm"})(QuantInputBase)
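The one-line registration above follows ModelOpt's registry-decorator pattern: `register(mapping)` returns a decorator that records which quantized class replaces which original class. A minimal pure-Python sketch of that pattern — all class and method names below are illustrative stand-ins, not ModelOpt's real implementation:

```python
# Minimal sketch of the registry-decorator pattern behind
# QuantModuleRegistry.register({nn.LayerNorm: "nn.LayerNorm"})(QuantInputBase).
class _Registry:
    def __init__(self):
        self._map = {}

    def register(self, classes: dict):
        # classes maps original class -> registry name; the returned decorator
        # records the quantized class to substitute for each original class.
        def decorator(quant_cls):
            for orig_cls, name in classes.items():
                self._map[orig_cls] = (name, quant_cls)
            return quant_cls
        return decorator

    def lookup(self, module):
        """Return (name, quant_cls) for a module instance, or None."""
        return self._map.get(type(module))

registry = _Registry()

class LayerNorm:          # stand-in for torch.nn.LayerNorm
    pass

class QuantInputBase:     # stand-in for ModelOpt's quantized base class
    pass

# Same call shape as the real registration line above.
registry.register({LayerNorm: "nn.LayerNorm"})(QuantInputBase)
```

During quantization, the converter would look up each module's class in the registry and swap in the quantized counterpart, which is how the LayerNorm output quantizer from the config becomes active.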

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 15 additions & 0 deletions
@@ -286,9 +286,24 @@ def register_hf_attentions_on_the_fly(model):

     attention_cls = set()
     registered_attn_module = False
+
+    # Skip attention wrappers that contain a nested "Attention" child on this specific
+    # instance (e.g. ViTAttention wraps ViTSelfAttention). Patching both would
+    # double-quantize eager_attention_forward. Checked per-instance (not by class) so a
+    # class reused as both wrapper and leaf is not dropped everywhere. In a 3-level
+    # hierarchy (Outer → Middle → Inner), both Outer and Middle are treated as wrappers
+    # and only Inner is registered.
+    def _wraps_nested_attention(module):
+        return any(
+            child is not module and type(child).__name__.endswith("Attention")
+            for _, child in module.named_modules()
+        )
+
     for name, module in model.named_modules():
         # Only register attention classes that are from Huggingface transformers
         if type(module).__name__.endswith("Attention"):
+            if _wraps_nested_attention(module):
+                continue
             attention_type = _QuantAttention.get_attn_type(module)
             # Add modules to be registered only if they arent already registered
             if (
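The wrapper-skipping predicate in the diff above is easy to exercise in isolation. A self-contained sketch using a tiny stub for `torch.nn.Module.named_modules` (which yields `(name, module)` for the module itself and all descendants) — the `Module` stub and constructor are assumptions for illustration only:

```python
# Standalone sketch of the _wraps_nested_attention check: ViTAttention wraps
# ViTSelfAttention, so the outer wrapper is skipped and only the inner leaf
# attention module would be registered for quantization.
class Module:
    def __init__(self, **children):
        self._children = children

    def named_modules(self, prefix=""):
        # Mimics torch.nn.Module.named_modules: yields self, then descendants.
        yield prefix, self
        for name, child in self._children.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)

class ViTSelfAttention(Module):
    pass

class ViTAttention(Module):
    pass

def wraps_nested_attention(module):
    # Same predicate as the diff: any strict descendant whose class name ends
    # with "Attention" marks this instance as a wrapper to skip.
    return any(
        child is not module and type(child).__name__.endswith("Attention")
        for _, child in module.named_modules()
    )

inner = ViTSelfAttention()
outer = ViTAttention(attention=inner)
```

Because the check walks `named_modules()` per instance rather than inspecting class names globally, a class that appears as both a wrapper and a leaf elsewhere in the model is only skipped where it actually wraps another attention module.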
