Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar
transformer vision models) when exported to ONNX with FP8 Q/DQ.
- fp8_exporter: move the attention-scaling Mul and the K-side Transpose to the
  Q side so DQ feeds MatMul directly, pre-transpose weight constants, and
  insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. The Q/DQ scale dtype
  now matches the graph's float dtype to keep strongly-typed builds
  consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16
inserts around Q/DQ by rewriting scale initializers to FP16, so TRT
fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native
input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry
so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose
children are also "*Attention" to avoid double-patching
eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
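The Softmax Q/DQ step above can be modeled with a toy pass over a minimal node list. The dict-based nodes and all tensor/scale names here are hypothetical stand-ins for the real onnx protos handled by fp8_exporter:

```python
# Toy model of the Softmax Q/DQ insertion, assuming a list of dict
# "nodes" standing in for an ONNX graph. All names are hypothetical.

def insert_softmax_qdq(nodes, scale_name="softmax_scale"):
    """Insert QuantizeLinear/DequantizeLinear after every Softmax output
    so TensorRT can pattern-match the MHA-v2 fusion."""
    out = []
    for node in nodes:
        out.append(node)
        if node["op"] != "Softmax":
            continue
        probs = node["outputs"][0]
        q_out, dq_out = probs + "_q", probs + "_dq"
        out.append({"op": "QuantizeLinear",
                    "inputs": [probs, scale_name], "outputs": [q_out]})
        out.append({"op": "DequantizeLinear",
                    "inputs": [q_out, scale_name], "outputs": [dq_out]})
        # Rewire consumers of the raw softmax output to the DQ output.
        for other in nodes:
            if other is not node:
                other["inputs"] = [dq_out if i == probs else i
                                   for i in other["inputs"]]
    return out

graph = [
    {"op": "MatMul", "inputs": ["q", "k_t"], "outputs": ["scores"]},
    {"op": "Softmax", "inputs": ["scores"], "outputs": ["probs"]},
    {"op": "MatMul", "inputs": ["probs", "v"], "outputs": ["ctx"]},
]
new_graph = insert_softmax_qdq(graph)
print([n["op"] for n in new_graph])
# → ['MatMul', 'Softmax', 'QuantizeLinear', 'DequantizeLinear', 'MatMul']
```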
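The Cast-folding and scale-dtype points can be sketched the same way, assuming scale initializers stored as (dtype, value) pairs; names are hypothetical stand-ins for the real ONNX protos:

```python
def fold_qdq_casts(nodes, initializers):
    """Drop a Cast feeding Q/DQ by rewriting the casted scale initializer
    to FP16 in place, so DQ can fuse into the downstream GEMM/MatMul.
    (Exporting the scale in the input's native dtype avoids the Cast in
    the first place; this pass cleans up graphs where one was inserted.)"""
    qdq_inputs = {i for n in nodes
                  if n["op"] in ("QuantizeLinear", "DequantizeLinear")
                  for i in n["inputs"]}
    kept = []
    for node in nodes:
        if (node["op"] == "Cast"
                and node["outputs"][0] in qdq_inputs
                and node["inputs"][0] in initializers):
            src = node["inputs"][0]
            _dtype, value = initializers[src]
            initializers[src] = ("fp16", value)  # rewrite scale dtype
            for consumer in nodes:  # bypass the Cast
                consumer["inputs"] = [src if i == node["outputs"][0] else i
                                      for i in consumer["inputs"]]
            continue  # Cast is folded away
        kept.append(node)
    return kept

nodes = [
    {"op": "Cast", "inputs": ["scale"], "outputs": ["scale_f16"]},
    {"op": "QuantizeLinear", "inputs": ["x", "scale_f16"], "outputs": ["x_q"]},
]
inits = {"scale": ("fp32", 0.0625)}
folded = fold_qdq_casts(nodes, inits)
print([n["op"] for n in folded], inits["scale"])
# → ['QuantizeLinear'] ('fp16', 0.0625)
```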
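The registry registration can be illustrated with stand-in classes; the real QuantModuleRegistry and TensorQuantizer live in modelopt.torch.quantization.nn, and their APIs may differ from this sketch:

```python
# Toy sketch of QuantModuleRegistry-style registration. All classes here
# are stand-ins; only the registration pattern is the point.

class QuantModuleRegistry:
    """Maps an original module class to its quantized replacement."""
    _registry = {}

    @classmethod
    def register(cls, original_cls):
        def decorator(quant_cls):
            cls._registry[original_cls] = quant_cls
            return quant_cls
        return decorator

    @classmethod
    def convert(cls, module):
        quant_cls = cls._registry.get(type(module))
        return quant_cls() if quant_cls else module


class LayerNorm:  # stand-in for torch.nn.LayerNorm
    pass


@QuantModuleRegistry.register(LayerNorm)
class QuantLayerNorm(LayerNorm):
    """Carries an output_quantizer, so LayerNorm output-quantizer
    entries in the quant config take effect."""
    def __init__(self):
        self.output_quantizer = object()  # stand-in for TensorQuantizer


converted = QuantModuleRegistry.convert(LayerNorm())
print(type(converted).__name__)  # QuantLayerNorm
```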
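The nested-wrapper skip can be sketched with stand-in module objects; the real plugin walks the HuggingFace nn.Module tree and matches class names:

```python
# Toy sketch of the nested-attention-wrapper check. Module is a stand-in
# for an nn.Module with named children.

class Module:
    def __init__(self, name, children=()):
        self.name, self.children = name, list(children)

def should_patch_attention(module):
    """Skip wrappers like ViTAttention whose children include another
    *Attention module (e.g. ViTSelfAttention) -- patching both would
    wrap eager_attention_forward twice."""
    if not module.name.endswith("Attention"):
        return False
    return not any(c.name.endswith("Attention") for c in module.children)

self_attn = Module("ViTSelfAttention")
wrapper = Module("ViTAttention", [self_attn])
print(should_patch_attention(wrapper))    # False: has an Attention child
print(should_patch_attention(self_attn))  # True: leaf attention module
```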
Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8
config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer,
disabled input quantizers on LayerNorm-followed layers, and
*_bmm_quantizer entries) plus accuracy + TRT-latency comparison
against an FP16 baseline.
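A sketch of what such a config might look like. FP8_DEFAULT_CFG below is a stand-in dict (the real one comes from modelopt.torch.quantization), and the pattern keys are illustrative, not the exact ones the example script uses:

```python
import copy

# Stand-in for modelopt's FP8 default config; (4, 3) denotes E4M3.
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3)},
        "*input_quantizer": {"num_bits": (4, 3)},
    },
    "algorithm": "max",
}

VIT_FP8_CFG = copy.deepcopy(FP8_DEFAULT_CFG)
VIT_FP8_CFG["quant_cfg"].update({
    # Quantize LayerNorm outputs (needs nn.LayerNorm in QuantModuleRegistry).
    "*layernorm*output_quantizer": {"num_bits": (4, 3)},
    # The LayerNorm output Q/DQ already covers these inputs; disabling
    # them avoids back-to-back Q/DQ pairs.
    "*attention*input_quantizer": {"enable": False},
    # Quantize the two batched matmuls inside attention (QK^T, probs @ V).
    "*q_bmm_quantizer": {"num_bits": (4, 3)},
    "*k_bmm_quantizer": {"num_bits": (4, 3)},
    "*v_bmm_quantizer": {"num_bits": (4, 3)},
})
```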
Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs
80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup
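For reference, the reported deltas and speedup follow directly from the raw numbers:

```python
# Sanity check of the reported figures (batch=1, RTX 6000 Ada).
fp16_ms, fp8_ms = 0.721, 0.646
speedup = fp16_ms / fp8_ms            # 0.721 / 0.646 ≈ 1.116
print(f"{speedup:.2f}x speedup")      # 1.12x speedup

top1_delta = round(80.96 - 81.16, 2)  # -0.2
top5_delta = round(95.44 - 95.50, 2)  # -0.06
```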
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
CHANGELOG.rst (+1 line)
@@ -17,6 +17,7 @@ Changelog
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale-Mul / K-transpose move before Q, Q/DQ insertion on Softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.