Commit f24abd4
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar
transformer vision models) when exported to ONNX with FP8 Q/DQ.
- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the
Q-side so DQ feeds MatMul directly, pre-transpose weight constants,
insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype
now matches the graph's float dtype to keep strongly-typed builds
consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16
inserts around Q/DQ by rewriting scale initializers to FP16, so TRT
fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native
input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry
so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose
children are also "*Attention" to avoid double-patching
eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
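The weight pre-transposition in the first bullet can be illustrated with plain numpy. This is a minimal sketch, not modelopt code: `pretranspose_weight` is a hypothetical helper showing why baking the Transpose into the initializer lets DQ feed MatMul directly.

```python
import numpy as np

def pretranspose_weight(w: np.ndarray) -> np.ndarray:
    """Fold a downstream Transpose into the weight constant itself.

    Instead of exporting MatMul(x, Transpose(W)), store W.T once so the
    graph reads MatMul(x, W_t) and the DQ node on the weight can feed
    the MatMul directly, with no Transpose in between.
    """
    return np.ascontiguousarray(w.T)

# Toy shapes: x is (batch, in), W is (out, in) as nn.Linear stores it.
x = np.arange(8, dtype=np.float32).reshape(2, 4)
w = np.arange(12, dtype=np.float32).reshape(3, 4)

ref = x @ w.T                      # original pattern: runtime Transpose
out = x @ pretranspose_weight(w)   # rewritten: Transpose folded away
assert np.allclose(ref, out)
```

The rewrite changes where the transpose happens, not the math, which is why the fused graph stays numerically identical.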
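The wrapper-skipping rule in the last bullet can be sketched in pure Python. The class names mirror HuggingFace ViT, but `should_patch` is illustrative, not the actual plugin code: a module is skipped whenever one of its direct children is itself a "*Attention" class, so only the innermost attention module gets patched.

```python
import fnmatch

class ViTSelfAttention:
    """Leaf module that actually calls eager_attention_forward."""

class ViTAttention:
    """Wrapper whose child is itself a '*Attention' module."""
    def __init__(self):
        self.attention = ViTSelfAttention()

def should_patch(module) -> bool:
    # If any direct child is also named "*Attention", this module is a
    # wrapper; patching it too would double-patch the attention forward.
    children = vars(module).values()
    return not any(
        fnmatch.fnmatch(type(c).__name__, "*Attention") for c in children
    )

assert should_patch(ViTSelfAttention())   # leaf: patch it
assert not should_patch(ViTAttention())   # wrapper: skip it
```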
Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8
config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer,
disabled input quantizers on LayerNorm-followed layers, and
*_bmm_quantizer entries) plus accuracy + TRT-latency comparison
against an FP16 baseline.
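The rough shape of such a config can be sketched as nested dicts. The wildcard keys below are hypothetical stand-ins modeled on the commit text (the authoritative config is in examples/torch_onnx/vit_mha_quantization.py); `num_bits=(4, 3)` denotes FP8 E4M3.

```python
# Stand-in for modelopt's FP8 preset; keys here are illustrative only.
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3)},
        "*input_quantizer": {"num_bits": (4, 3)},
    }
}

# Extend the preset as the commit message describes.
VIT_FP8_CFG = {"quant_cfg": dict(FP8_DEFAULT_CFG["quant_cfg"])}
VIT_FP8_CFG["quant_cfg"].update({
    # Quantize LayerNorm outputs (honored once nn.LayerNorm is
    # registered in QuantModuleRegistry).
    "*layernorm*output_quantizer": {"num_bits": (4, 3)},
    # Layers fed directly by a quantized LayerNorm don't need their own
    # input quantizer.
    "*linear*input_quantizer": {"enable": False},
    # Quantize the attention BMMs so TRT can apply MHA-v2 fusion.
    "*bmm_quantizer": {"num_bits": (4, 3)},
})
```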
Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs
80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup
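For context on the FP8 numbers, a simplified sketch of the usual E4M3 scale convention (scale = amax / 448, where 448 is the largest finite E4M3 value). This is not modelopt's implementation: real Q/DQ also rounds to the nearest representable E4M3 value, which the sketch omits, keeping only the scale-clamp-rescale structure.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def fp8_fake_quant(x: np.ndarray, amax: float) -> np.ndarray:
    """Clamp-only FP8 fake quantization: scale into the E4M3 range,
    clamp, and scale back. Mantissa rounding is omitted for brevity."""
    scale = amax / E4M3_MAX
    q = np.clip(x / scale, -E4M3_MAX, E4M3_MAX)
    return q * scale

# Values inside [-amax, amax] round-trip; outliers clamp to +/- amax.
out = fp8_fake_quant(np.array([0.5, -2.0, 10.0], dtype=np.float32), amax=2.0)
assert np.allclose(out, [0.5, -2.0, 2.0])
```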
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>

1 parent 010b220
9 files changed: 448 additions & 42 deletions
File tree
- examples/torch_onnx
- modelopt
  - onnx
    - export
  - torch
    - _deploy/utils
    - quantization
      - nn
        - modules
      - plugins