Commit 928d417
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar
transformer vision models) when exported to ONNX with FP8 Q/DQ.
- fp8_exporter: move the attention-scaling Mul and the K-side Transpose
  to the Q branch so DQ feeds MatMul directly, pre-transpose weight
  constants, and insert FP8 Q/DQ on Softmax outputs to enable MHA-v2
  fusion. Scale dtype now matches the graph's float dtype to keep
  strongly-typed builds consistent.
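The Q-side rewrite rests on a simple algebraic identity: scaling Q before the MatMul gives the same product as scaling QK^T afterwards, so the Mul can be hoisted off the MatMul output. A toy numeric check of that identity (pure Python, not the exporter's code):

```python
import math

# The rewrite relies on (Q * s) @ K^T == (Q @ K^T) * s, so the attention
# scale can be folded onto the Q branch, letting DQ feed MatMul directly.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def scale(m, s):
    return [[v * s for v in row] for row in m]

def transpose(m):
    return [list(r) for r in zip(*m)]

q = [[1.0, 2.0], [3.0, 4.0]]
k = [[0.5, -1.0], [2.0, 0.25]]
s = 1.0 / math.sqrt(2)  # 1/sqrt(head_dim), head_dim = 2

post = scale(matmul(q, transpose(k)), s)   # original: scale after QK^T
pre = matmul(scale(q, s), transpose(k))    # rewritten: scale on the Q side
assert all(abs(a - b) < 1e-12
           for ra, rb in zip(post, pre) for a, b in zip(ra, rb))
```

The same reasoning lets the K Transpose be expressed on the Q side, so the dequantized tensors reach the MatMul without intervening pointwise ops that would block TensorRT's attention fusion.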
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16
inserts around Q/DQ by rewriting scale initializers to FP16, so TRT
fuses DQ into the downstream GEMM/MatMul kernel.
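The Cast fold can be sketched on a hypothetical dict-based mini-graph (illustrative structure only, not modelopt's onnx/utils code): rewriting the DQ scale initializer to FP16 makes the downstream Cast a no-op, so it can be removed and the DQ output wired straight into the MatMul.

```python
# Mini-graph sketch: fold a Cast(to fp16) that follows DequantizeLinear
# by rewriting the scale initializer dtype, then dropping the Cast.
graph = {
    "initializers": {"dq_scale": {"dtype": "fp32", "value": 0.0078125}},
    "nodes": [
        {"op": "DequantizeLinear", "inputs": ["x_q", "dq_scale"], "output": "x_f32"},
        {"op": "Cast", "to": "fp16", "inputs": ["x_f32"], "output": "x_f16"},
        {"op": "MatMul", "inputs": ["x_f16", "w"], "output": "y"},
    ],
}

def fold_cast_after_dq(g):
    kept = []
    for n in g["nodes"]:
        if n["op"] == "Cast" and n["to"] == "fp16":
            producer = next(p for p in kept if p.get("output") == n["inputs"][0])
            if producer["op"] == "DequantizeLinear":
                # Rewrite the scale initializer so DQ emits fp16 directly.
                g["initializers"][producer["inputs"][1]]["dtype"] = "fp16"
                # Reconnect consumers of the Cast output to the DQ output.
                for c in g["nodes"]:
                    c["inputs"] = [producer["output"] if i == n["output"] else i
                                   for i in c["inputs"]]
                continue  # drop the Cast node
        kept.append(n)
    g["nodes"] = kept

fold_cast_after_dq(graph)
assert [n["op"] for n in graph["nodes"]] == ["DequantizeLinear", "MatMul"]
assert graph["initializers"]["dq_scale"]["dtype"] == "fp16"
```

With no Cast between DQ and the GEMM/MatMul, TensorRT can fuse dequantization into the kernel itself.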
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native
input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry
so LayerNorm output quantizers are honored.
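The registry mechanism can be sketched with stand-in classes (hypothetical, not modelopt's real QuantModuleRegistry or a real quantized LayerNorm): a module type only gets quantizers attached if the registry maps it to a quantized wrapper, which is why nn.LayerNorm must be registered for its output quantizer to take effect.

```python
class LayerNorm:                      # stand-in for torch.nn.LayerNorm
    pass

class QuantLayerNorm:                 # stand-in quantized wrapper
    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.output_quantizer_enabled = True

class QuantModuleRegistry:
    _registry: dict = {}

    @classmethod
    def register(cls, module_cls, quant_cls):
        cls._registry[module_cls] = quant_cls

    @classmethod
    def convert(cls, module):
        quant_cls = cls._registry.get(type(module))
        # Unregistered module types pass through unchanged, so their
        # configured quantizers are silently ignored.
        return quant_cls(module) if quant_cls else module

ln = LayerNorm()
assert QuantModuleRegistry.convert(ln) is ln          # not registered: untouched
QuantModuleRegistry.register(LayerNorm, QuantLayerNorm)
assert QuantModuleRegistry.convert(ln).output_quantizer_enabled
```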
- torch/quantization/plugins/huggingface: skip attention wrappers whose
children are also "*Attention" to avoid double-patching
eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
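The skip rule above can be expressed as a small predicate (the helper below is illustrative, not the plugin's actual code; the class names mirror the ViT case): patch only the innermost attention module, so a wrapper like ViTAttention that contains a ViTSelfAttention child is left alone and eager_attention_forward is wrapped once.

```python
def should_patch(module_cls_name: str, child_cls_names: list[str]) -> bool:
    """Patch an *Attention module only if none of its children are also
    *Attention; otherwise the wrapper and the inner module would both
    patch eager_attention_forward (double-patching)."""
    is_attention = module_cls_name.endswith("Attention")
    has_attention_child = any(c.endswith("Attention") for c in child_cls_names)
    return is_attention and not has_attention_child

assert should_patch("ViTSelfAttention", []) is True                # innermost: patch
assert should_patch("ViTAttention", ["ViTSelfAttention"]) is False  # wrapper: skip
```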
Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8
config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer,
disabled input quantizers on LayerNorm-followed layers, and
*_bmm_quantizer entries) plus accuracy + TRT-latency comparison
against an FP16 baseline.
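The shape of such a config can be sketched as a plain dict (the base config and the wildcard quantizer keys below are illustrative stand-ins in the style of modelopt configs, not copied from the actual example file):

```python
import copy

# Stand-in for modelopt's FP8_DEFAULT_CFG (illustrative; (4, 3) denotes
# FP8 E4M3 in num_bits form).
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3)},
        "default": {"enable": False},
    },
    "algorithm": "max",
}

# ViT-FP8 config in the spirit the example describes: add a LayerNorm
# output quantizer, disable input quantizers on layers fed directly by a
# LayerNorm (the output quantizer already covers that tensor), and
# quantize the attention BMMs.
VIT_FP8_CFG = copy.deepcopy(FP8_DEFAULT_CFG)
VIT_FP8_CFG["quant_cfg"].update({
    "*layernorm*output_quantizer": {"num_bits": (4, 3), "enable": True},
    "*attention*input_quantizer": {"enable": False},
    "*bmm_quantizer": {"num_bits": (4, 3)},
})
```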
Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs
80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup
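A quick arithmetic check of the reported figures:

```python
# The 1.12x speedup and -0.20% top-1 delta follow from the raw numbers.
fp16_ms, fp8_ms = 0.721, 0.646
speedup = fp16_ms / fp8_ms
assert round(speedup, 2) == 1.12      # matches the reported 1.12x

top1_fp16, top1_fp8 = 81.16, 80.96
assert round(top1_fp16 - top1_fp8, 2) == 0.20   # -0.20% top-1
```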
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 010b220
10 files changed: 488 additions, 42 deletions
File tree
- examples/torch_onnx
- modelopt
  - onnx
    - export
  - torch
    - _deploy/utils
    - quantization
      - nn
        - modules
      - plugins