Commit d6533ac
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy + TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 9f8188d commit d6533ac

8 files changed

Lines changed: 933 additions & 27 deletions

File tree

examples/torch_onnx/vit_mha_quantization.py

Lines changed: 538 additions & 0 deletions
Large diffs are not rendered by default.
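The example's diff is not rendered above, so here is a minimal sketch of the kind of config the commit message describes, assuming modelopt's mtq config-dict format. The pattern strings and variable names below are illustrative assumptions, not the actual contents of vit_mha_quantization.py:

    import copy

    import modelopt.torch.quantization as mtq

    # Start from the stock FP8 recipe and extend it as the commit message outlines.
    VIT_FP8_CFG = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    VIT_FP8_CFG["quant_cfg"].update(
        {
            # Quantize LayerNorm outputs so a Q/DQ pair sits in front of the attention GEMMs.
            "*layernorm*output_quantizer": {"num_bits": (4, 3), "axis": None},
            # FP8 quantizers on the attention BMM inputs (the "*_bmm_quantizer" entries).
            "*q_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
            "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
            "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None},
            # Layers fed directly by a quantized LayerNorm output skip their own input quantizer.
            "*query*input_quantizer": {"enable": False},
            "*key*input_quantizer": {"enable": False},
            "*value*input_quantizer": {"enable": False},
        }
    )

    # model = mtq.quantize(model, VIT_FP8_CFG, forward_loop=calibrate)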

modelopt/onnx/export/fp8_exporter.py

Lines changed: 254 additions & 3 deletions
@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         # Fold constants is required since the scale is not constant yet.
         graph.cleanup().toposort().fold_constants().cleanup()
 
+        n_t_folded = 0
+
         for node in graph.nodes:
             if node.op == "TRT_FP8QuantizeLinear":
                 # Should not remove input QDQ (only process weight quantization)
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                         f"QDQ does not occur in pairs. You reached {dq_op.op}"
                     )
 
+                # Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or
+                # ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul.
+                transpose_to_remove = None
+                cast_to_remove = None
+                for candidate in list(dq_op.outputs[0].outputs):
+                    if candidate.op == "Cast":
+                        cast_to_remove = candidate
+                        candidate = next(
+                            (c for c in candidate.outputs[0].outputs if c.op == "Transpose"),
+                            None,
+                        )
+                        if candidate is None:
+                            cast_to_remove = None
+                            continue
+                    if candidate.op != "Transpose":
+                        cast_to_remove = None
+                        continue
+                    if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
+                        perm = candidate.attrs.get("perm", None)
+                        torch_weights = (
+                            torch_weights.permute(*perm).contiguous()
+                            if perm is not None
+                            else torch_weights.T.contiguous()
+                        )
+                        transpose_to_remove = candidate
+                        break
+
                 # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
                 numpy_weights = (
                     (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
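The pre-transpose rewrite above relies on a simple identity: feeding MatMul a weight constant that was permuted once at export time gives the same result as transposing it at runtime. A quick standalone torch check of that identity on toy shapes (not tied to the exporter):

    import torch

    x = torch.randn(8, 64)    # activation
    w = torch.randn(16, 64)   # constant weight as stored in the graph

    # Original pattern: DQ(W) -> Transpose(perm=[1, 0]) -> MatMul
    y_ref = x @ w.permute(1, 0)

    # Rewritten pattern: weight pre-permuted at export time, DQ feeds MatMul directly
    w_pre = w.permute(1, 0).contiguous()
    y_new = x @ w_pre

    assert torch.allclose(y_ref, y_new)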
@@ -94,20 +123,232 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                 dq_op.inputs[0] = onnx_weights_fp8
                 dq_op.op = "DequantizeLinear"
                 dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
+                dq_op.outputs[0].shape = list(numpy_weights.shape)
+
+                if transpose_to_remove is not None:
+                    t_out = transpose_to_remove.outputs[0]
+                    for consumer in list(t_out.outputs):
+                        for i, inp in enumerate(consumer.inputs):
+                            if inp is t_out:
+                                consumer.inputs[i] = dq_op.outputs[0]
+                    transpose_to_remove.outputs.clear()
+                    if cast_to_remove is not None:
+                        cast_to_remove.outputs.clear()
+                    n_t_folded += 1
 
         graph.cleanup().toposort()
         end_time = time.time()
+        if n_t_folded > 0:
+            logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression")
         print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.")
 
         return gs.export_onnx(graph)
 
+    @staticmethod
+    def _move_mul_before_qdq(graph: gs.Graph) -> int:
+        """Move attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion.
+
+        Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path).
+        """
+        count = 0
+        for mul_node in list(graph.nodes):
+            if mul_node.op != "Mul":
+                continue
+
+            const_input = next(
+                (i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1),
+                None,
+            )
+            tensor_input = next(
+                (i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None
+            )
+            if const_input is None or tensor_input is None:
+                continue
+            if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1):
+                continue
+
+            producer = tensor_input.inputs[0]
+            transpose_node = producer if producer.op == "Transpose" else None
+            dq_node = producer if producer.op == "DequantizeLinear" else None
+            if transpose_node is not None:
+                t_input = transpose_node.inputs[0]
+                if (
+                    isinstance(t_input, gs.Variable)
+                    and len(t_input.inputs) == 1
+                    and t_input.inputs[0].op == "DequantizeLinear"
+                ):
+                    dq_node = t_input.inputs[0]
+            if dq_node is None:
+                continue
+
+            q_output = dq_node.inputs[0]
+            if (
+                not isinstance(q_output, gs.Variable)
+                or len(q_output.inputs) != 1
+                or q_output.inputs[0].op != "QuantizeLinear"
+            ):
+                continue
+            q_node = q_output.inputs[0]
+            q_input = q_node.inputs[0]
+            if not isinstance(q_input, gs.Variable):
+                continue
+
+            mul_output = mul_node.outputs[0]
+            if not any(c.op == "MatMul" for c in mul_output.outputs):
+                continue
+
+            new_mul_output = gs.Variable(
+                q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape
+            )
+            graph.nodes.append(
+                gs.Node(
+                    op="Mul",
+                    name=mul_node.name + "_moved",
+                    inputs=[q_input, const_input],
+                    outputs=[new_mul_output],
+                )
+            )
+            q_node.inputs[0] = new_mul_output
+
+            replacement = (
+                transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0]
+            )
+            for consumer in list(mul_output.outputs):
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is mul_output:
+                        consumer.inputs[i] = replacement
+            mul_node.outputs.clear()
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
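For intuition on what the Mul move changes numerically: the Q scale is left untouched, so scaling before Q instead of after DQ reproduces the original values up to quantization rounding. A toy numpy sketch with fp8 emulated as rounding onto the scale grid (an illustration only, not the exporter's code path):

    import numpy as np


    def fake_quant(x, scale):
        # Emulate per-tensor quantize -> dequantize by rounding onto the scale grid.
        return np.round(x / scale) * scale


    rng = np.random.default_rng(0)
    q = rng.standard_normal((4, 16)).astype(np.float32)
    c = np.float32(1.0 / np.sqrt(64.0))  # attention scaling, e.g. 1/sqrt(head_dim)
    scale = np.float32(0.05)

    after_dq = fake_quant(q, scale) * c   # original:  Q -> DQ -> Mul -> MatMul
    before_q = fake_quant(q * c, scale)   # rewritten: Mul -> Q -> DQ -> MatMul

    # The two orderings differ by at most about one quantization step.
    print(np.abs(after_dq - before_q).max())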
+    @staticmethod
+    def _move_transpose_before_qdq(graph: gs.Graph) -> int:
+        """Move Transpose from ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path)."""
+        count = 0
+        for transpose_node in list(graph.nodes):
+            if transpose_node.op != "Transpose":
+                continue
+
+            t_input = transpose_node.inputs[0]
+            if (
+                not isinstance(t_input, gs.Variable)
+                or len(t_input.inputs) != 1
+                or t_input.inputs[0].op != "DequantizeLinear"
+            ):
+                continue
+            dq_node = t_input.inputs[0]
+
+            dq_input = dq_node.inputs[0]
+            if (
+                not isinstance(dq_input, gs.Variable)
+                or len(dq_input.inputs) != 1
+                or dq_input.inputs[0].op != "QuantizeLinear"
+            ):
+                continue
+            q_node = dq_input.inputs[0]
+            q_input = q_node.inputs[0]
+            if not isinstance(q_input, gs.Variable):
+                continue
+
+            t_output = transpose_node.outputs[0]
+            if not any(c.op == "MatMul" for c in t_output.outputs):
+                continue
+
+            new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype)
+            graph.nodes.append(
+                gs.Node(
+                    op="Transpose",
+                    name=transpose_node.name + "_moved",
+                    inputs=[q_input],
+                    outputs=[new_t_output],
+                    attrs=transpose_node.attrs,
+                )
+            )
+            q_node.inputs[0] = new_t_output
+
+            for consumer in list(t_output.outputs):
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is t_output:
+                        consumer.inputs[i] = dq_node.outputs[0]
+            transpose_node.outputs.clear()
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
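Moving the Transpose ahead of Q/DQ, by contrast, is exact: an element-wise per-tensor quantize/dequantize commutes with a pure permutation. A small numpy check, independent of the exporter:

    import numpy as np


    def fake_quant(x, scale):
        # Element-wise per-tensor quantize -> dequantize.
        return np.round(x / scale) * scale


    rng = np.random.default_rng(0)
    k = rng.standard_normal((2, 8, 16, 64)).astype(np.float32)
    scale = np.float32(0.05)
    perm = (0, 1, 3, 2)

    dq_then_t = fake_quant(k, scale).transpose(perm)   # original:  Q -> DQ -> Transpose
    t_then_dq = fake_quant(k.transpose(perm), scale)   # rewritten: Transpose -> Q -> DQ

    assert np.array_equal(dq_then_t, t_then_dq)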
+    @staticmethod
+    def _insert_qdq_after_softmax(graph: gs.Graph) -> int:
+        """Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion).
+
+        Torch export does not quantize softmax output; scale=1/448 saturates exactly at 1.0
+        (softmax range is [0, 1]) while covering the full FP8 E4M3 representable range.
+        """
+        import numpy as np
+
+        count = 0
+        for softmax_node in list(graph.nodes):
+            if softmax_node.op != "Softmax":
+                continue
+            softmax_output = softmax_node.outputs[0]
+            if not any(c.op == "MatMul" for c in softmax_output.outputs):
+                continue
+            if any(c.op == "QuantizeLinear" for c in softmax_output.outputs):
+                continue
+
+            # Match scale dtype to the graph's current float dtype so TRT stronglyTyped
+            # sees consistent Q/DQ types with the surrounding compute.
+            scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32
+            scale_val = np.array(1.0 / 448.0, dtype=scale_dtype)
+            scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val)
+            dq_scale_constant = gs.Constant(
+                softmax_node.name + "/softmax_dq_scale", scale_val.copy()
+            )
+
+            zp_tensor = onnx.TensorProto()
+            zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
+            zp_tensor.dims.extend([1])
+            zp_tensor.raw_data = b"\x00"
+            zp_constant = gs.Constant(
+                softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor)
+            )
+
+            q_output = gs.Variable(softmax_node.name + "/q_output")
+            dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype)
+            q_node = gs.Node(
+                op="QuantizeLinear",
+                name=softmax_node.name + "/QuantizeLinear",
+                inputs=[softmax_output, scale_constant, zp_constant],
+                outputs=[q_output],
+                attrs={"saturate": 1},
+            )
+            dq_node = gs.Node(
+                op="DequantizeLinear",
+                name=softmax_node.name + "/DequantizeLinear",
+                inputs=[q_output, dq_scale_constant],
+                outputs=[dq_output],
+            )
+            graph.nodes.extend([q_node, dq_node])
+
+            for consumer in list(softmax_output.outputs):
+                if consumer is q_node:
+                    continue
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is softmax_output:
+                        consumer.inputs[i] = dq_output
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
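The 1/448 in the docstring above is just the FP8 E4M3 maximum: the largest finite float8_e4m3fn value is 448 (1.75 * 2**8), so softmax's peak output of 1.0 divided by the scale lands exactly on that maximum and nothing in [0, 1] is clipped. As arithmetic:

    import math

    E4M3_MAX = 448.0          # largest finite float8_e4m3fn value: 1.75 * 2**8
    scale = 1.0 / E4M3_MAX

    assert 1.75 * 2**8 == E4M3_MAX
    # Peak softmax output quantizes to 1.0 / scale = 448, i.e. exactly the FP8 max,
    # so with saturate=1 the full positive E4M3 range covers [0, 1] without clipping.
    assert math.isclose(1.0 / scale, E4M3_MAX)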
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Post-processes the ONNX model for FP8 quantization.
 
-        Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
-        - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
-        - TRT_FP8DequantizeLinear -> DequantizeLinear
+        Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and
+        rewrites attention scaling / K-transpose / softmax-output patterns so TRT
+        can fuse DQ into the attention MatMul kernels.
 
         Args:
             onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
@@ -144,5 +385,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                     f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
                 )
 
+        # Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls.
+        n_mul = FP8QuantExporter._move_mul_before_qdq(graph)
+        n_t = FP8QuantExporter._move_transpose_before_qdq(graph)
+        n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph)
+        if n_mul or n_t or n_sm:
+            logger.info(
+                f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; "
+                f"inserted QDQ on {n_sm} Softmax outputs"
+            )
+
         graph.cleanup().toposort()
         return gs.export_onnx(graph)
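A hedged end-to-end sketch of exercising the new rewrites: run post_process on an exported FP8-QDQ ONNX file, then walk the result with onnx-graphsurgeon to confirm no attention Softmax feeds a MatMul directly anymore. File names are illustrative; the import path follows the file path shown in this diff:

    import onnx
    import onnx_graphsurgeon as gs

    from modelopt.onnx.export.fp8_exporter import FP8QuantExporter

    model = onnx.load("vit_fp8_qdq.onnx")            # illustrative input
    model = FP8QuantExporter.post_process(model)
    onnx.save(model, "vit_fp8_qdq.post.onnx")

    graph = gs.import_onnx(model)
    for node in graph.nodes:
        if node.op != "Softmax":
            continue
        # After post_process, attention Softmax reaches MatMul only through Q -> DQ.
        assert not any(c.op == "MatMul" for c in node.outputs[0].outputs), node.name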

modelopt/onnx/utils.py

Lines changed: 90 additions & 0 deletions
@@ -1415,6 +1415,96 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None:
             consumer.input[i] = input_tensor
 
 
+_DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
+_Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"}
+
+
+def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
+    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
+    import numpy as np
+
+    if scale_init.data_type != onnx.TensorProto.FLOAT:
+        return
+    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
+    if not scale_data.size:
+        scale_data = np.array(scale_init.float_data, dtype=np.float32)
+    scale_init.data_type = onnx.TensorProto.FLOAT16
+    scale_init.raw_data = scale_data.astype(np.float16).tobytes()
+    del scale_init.float_data[:]
+
+
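A quick self-contained check of the in-place retyping the helper above performs; the tensor is synthetic and the import path is inferred from this file's location:

    import numpy as np
    import onnx

    from modelopt.onnx.utils import _scale_fp32_to_fp16

    scale = onnx.helper.make_tensor(
        "dq_scale", onnx.TensorProto.FLOAT, [1],
        np.array([0.0123], dtype=np.float32).tobytes(), raw=True,
    )

    _scale_fp32_to_fp16(scale)

    assert scale.data_type == onnx.TensorProto.FLOAT16
    print(np.frombuffer(scale.raw_data, dtype=np.float16))  # ~[0.0123], now stored as FP16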
+def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove ``DQ → Cast(FP32→FP16)`` patterns inserted by ``convert_float_to_float16``.
+
+    The DQ scale is rewritten to FP16 so DQ natively produces FP16 output, enabling
+    TRT to fuse DQ directly into the downstream compute op.
+    """
+    producer_map = {out: node for node in onnx_model.graph.node for out in node.output}
+    initializers = {init.name: init for init in onnx_model.graph.initializer}
+
+    to_remove = []
+    for node in onnx_model.graph.node:
+        if node.op_type != "Cast":
+            continue
+        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
+        if cast_to != onnx.TensorProto.FLOAT16:
+            continue
+        producer = producer_map.get(node.input[0])
+        if producer is None or producer.op_type not in _DQ_OPS:
+            continue
+
+        if len(producer.input) >= 2 and producer.input[1] in initializers:
+            _scale_fp32_to_fp16(initializers[producer.input[1]])
+
+        _bypass_cast_node(onnx_model, node)
+        to_remove.append(node)
+
+        for vi in onnx_model.graph.value_info:
+            if vi.name == producer.output[0]:
+                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
+                break
+
+    logger.debug(f"Folded {len(to_remove)} DQ -> Cast(FP32->FP16) patterns")
+    for node in to_remove:
+        onnx_model.graph.node.remove(node)
+    return onnx_model
+
+
+def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.
+
+    The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly.
+    """
+    consumer_map: dict[str, list[onnx.NodeProto]] = {}
+    for node in onnx_model.graph.node:
+        for inp in node.input:
+            consumer_map.setdefault(inp, []).append(node)
+    initializers = {init.name: init for init in onnx_model.graph.initializer}
+
+    to_remove = []
+    for node in onnx_model.graph.node:
+        if node.op_type != "Cast":
+            continue
+        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
+        if cast_to != onnx.TensorProto.FLOAT:
+            continue
+        consumers = consumer_map.get(node.output[0], [])
+        if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
+            continue
+
+        for q_node in consumers:
+            if len(q_node.input) >= 2 and q_node.input[1] in initializers:
+                _scale_fp32_to_fp16(initializers[q_node.input[1]])
+
+        _bypass_cast_node(onnx_model, node)
+        to_remove.append(node)
+
+    logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
+    for node in to_remove:
+        onnx_model.graph.node.remove(node)
+    return onnx_model
+
+
 def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
     """Check if a Constant -> Cast pattern can be folded."""
     assert node.op_type == "Cast"
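A hedged sketch of where the two folding helpers would typically run: after an FP32-to-FP16 conversion pass has wrapped the Q/DQ nodes in Casts. File names are illustrative; the import path follows this file's location:

    import onnx

    from modelopt.onnx.utils import fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts

    # An FP8-QDQ model already converted to FP16 elsewhere, leaving Casts around Q/DQ.
    model = onnx.load("vit_fp8_qdq_fp16.onnx")

    model = fold_dq_fp32_to_fp16_casts(model)  # drop DQ -> Cast(FP32->FP16), retype DQ scales to FP16
    model = fold_q_fp16_to_fp32_casts(model)   # drop Cast(FP16->FP32) -> Q, retype Q scales to FP16

    onnx.save(model, "vit_fp8_qdq_fp16_folded.onnx")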
