Skip to content

Commit c436553

Browse files
committed
Add FP8 MHA quantization support for HuggingFace ViT
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ. - fp8_exporter: rewrite attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent. - onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel. - torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native input dtype so no Cast is injected between graph and Q/DQ. - torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored. - torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention). Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy + TRT-latency comparison against an FP16 baseline. Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1): - Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8) — -0.20% / -0.06% - TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 010b220 commit c436553

8 files changed

Lines changed: 895 additions & 26 deletions

File tree

examples/torch_onnx/vit_mha_quantization.py

Lines changed: 538 additions & 0 deletions
Large diffs are not rendered by default.

modelopt/onnx/export/fp8_exporter.py

Lines changed: 254 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
6262
# Folding constants is required since the scale is not constant yet.
6363
graph.cleanup().toposort().fold_constants().cleanup()
6464

65+
n_t_folded = 0
66+
6567
for node in graph.nodes:
6668
if node.op == "TRT_FP8QuantizeLinear":
6769
# Should not remove input QDQ (only process weight quantization)
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
7880
f"QDQ does not occur in pairs. You reached {dq_op.op}"
7981
)
8082

83+
# Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or
84+
# ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul.
85+
transpose_to_remove = None
86+
cast_to_remove = None
87+
for candidate in list(dq_op.outputs[0].outputs):
88+
if candidate.op == "Cast":
89+
cast_to_remove = candidate
90+
candidate = next(
91+
(c for c in candidate.outputs[0].outputs if c.op == "Transpose"),
92+
None,
93+
)
94+
if candidate is None:
95+
cast_to_remove = None
96+
continue
97+
if candidate.op != "Transpose":
98+
cast_to_remove = None
99+
continue
100+
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
101+
perm = candidate.attrs.get("perm", None)
102+
torch_weights = (
103+
torch_weights.permute(*perm).contiguous()
104+
if perm is not None
105+
else torch_weights.T.contiguous()
106+
)
107+
transpose_to_remove = candidate
108+
break
109+
81110
# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
82111
numpy_weights = (
83112
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
@@ -94,9 +123,23 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
94123
dq_op.inputs[0] = onnx_weights_fp8
95124
dq_op.op = "DequantizeLinear"
96125
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
126+
dq_op.outputs[0].shape = list(numpy_weights.shape)
127+
128+
if transpose_to_remove is not None:
129+
t_out = transpose_to_remove.outputs[0]
130+
for consumer in list(t_out.outputs):
131+
for i, inp in enumerate(consumer.inputs):
132+
if inp is t_out:
133+
consumer.inputs[i] = dq_op.outputs[0]
134+
transpose_to_remove.outputs.clear()
135+
if cast_to_remove is not None:
136+
cast_to_remove.outputs.clear()
137+
n_t_folded += 1
97138

98139
graph.cleanup().toposort()
99140
end_time = time.time()
141+
if n_t_folded > 0:
142+
logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression")
100143
print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.")
101144

102145
return gs.export_onnx(graph)
@@ -175,13 +218,212 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int:
175218

176219
return count
177220

221+
@staticmethod
222+
def _move_mul_before_qdq(graph: gs.Graph) -> int:
223+
"""Move attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion.
224+
225+
Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path).
226+
"""
227+
count = 0
228+
for mul_node in list(graph.nodes):
229+
if mul_node.op != "Mul":
230+
continue
231+
232+
const_input = next(
233+
(i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1),
234+
None,
235+
)
236+
tensor_input = next(
237+
(i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None
238+
)
239+
if const_input is None or tensor_input is None:
240+
continue
241+
if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1):
242+
continue
243+
244+
producer = tensor_input.inputs[0]
245+
transpose_node = producer if producer.op == "Transpose" else None
246+
dq_node = producer if producer.op == "DequantizeLinear" else None
247+
if transpose_node is not None:
248+
t_input = transpose_node.inputs[0]
249+
if (
250+
isinstance(t_input, gs.Variable)
251+
and len(t_input.inputs) == 1
252+
and t_input.inputs[0].op == "DequantizeLinear"
253+
):
254+
dq_node = t_input.inputs[0]
255+
if dq_node is None:
256+
continue
257+
258+
q_output = dq_node.inputs[0]
259+
if (
260+
not isinstance(q_output, gs.Variable)
261+
or len(q_output.inputs) != 1
262+
or q_output.inputs[0].op != "QuantizeLinear"
263+
):
264+
continue
265+
q_node = q_output.inputs[0]
266+
q_input = q_node.inputs[0]
267+
if not isinstance(q_input, gs.Variable):
268+
continue
269+
270+
mul_output = mul_node.outputs[0]
271+
if not any(c.op == "MatMul" for c in mul_output.outputs):
272+
continue
273+
274+
new_mul_output = gs.Variable(
275+
q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape
276+
)
277+
graph.nodes.append(
278+
gs.Node(
279+
op="Mul",
280+
name=mul_node.name + "_moved",
281+
inputs=[q_input, const_input],
282+
outputs=[new_mul_output],
283+
)
284+
)
285+
q_node.inputs[0] = new_mul_output
286+
287+
replacement = (
288+
transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0]
289+
)
290+
for consumer in list(mul_output.outputs):
291+
for i, inp in enumerate(consumer.inputs):
292+
if inp is mul_output:
293+
consumer.inputs[i] = replacement
294+
mul_node.outputs.clear()
295+
count += 1
296+
297+
graph.cleanup().toposort()
298+
return count
299+
300+
@staticmethod
301+
def _move_transpose_before_qdq(graph: gs.Graph) -> int:
302+
"""Move Transpose from ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path)."""
303+
count = 0
304+
for transpose_node in list(graph.nodes):
305+
if transpose_node.op != "Transpose":
306+
continue
307+
308+
t_input = transpose_node.inputs[0]
309+
if (
310+
not isinstance(t_input, gs.Variable)
311+
or len(t_input.inputs) != 1
312+
or t_input.inputs[0].op != "DequantizeLinear"
313+
):
314+
continue
315+
dq_node = t_input.inputs[0]
316+
317+
dq_input = dq_node.inputs[0]
318+
if (
319+
not isinstance(dq_input, gs.Variable)
320+
or len(dq_input.inputs) != 1
321+
or dq_input.inputs[0].op != "QuantizeLinear"
322+
):
323+
continue
324+
q_node = dq_input.inputs[0]
325+
q_input = q_node.inputs[0]
326+
if not isinstance(q_input, gs.Variable):
327+
continue
328+
329+
t_output = transpose_node.outputs[0]
330+
if not any(c.op == "MatMul" for c in t_output.outputs):
331+
continue
332+
333+
new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype)
334+
graph.nodes.append(
335+
gs.Node(
336+
op="Transpose",
337+
name=transpose_node.name + "_moved",
338+
inputs=[q_input],
339+
outputs=[new_t_output],
340+
attrs=transpose_node.attrs,
341+
)
342+
)
343+
q_node.inputs[0] = new_t_output
344+
345+
for consumer in list(t_output.outputs):
346+
for i, inp in enumerate(consumer.inputs):
347+
if inp is t_output:
348+
consumer.inputs[i] = dq_node.outputs[0]
349+
transpose_node.outputs.clear()
350+
count += 1
351+
352+
graph.cleanup().toposort()
353+
return count
354+
355+
@staticmethod
356+
def _insert_qdq_after_softmax(graph: gs.Graph) -> int:
357+
"""Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion).
358+
359+
Torch export does not quantize softmax output; scale=1/448 saturates exactly at 1.0
360+
(softmax range is [0, 1]) while covering the full FP8 E4M3 representable range.
361+
"""
362+
import numpy as np
363+
364+
count = 0
365+
for softmax_node in list(graph.nodes):
366+
if softmax_node.op != "Softmax":
367+
continue
368+
softmax_output = softmax_node.outputs[0]
369+
if not any(c.op == "MatMul" for c in softmax_output.outputs):
370+
continue
371+
if any(c.op == "QuantizeLinear" for c in softmax_output.outputs):
372+
continue
373+
374+
# Match scale dtype to the graph's current float dtype so TRT stronglyTyped
375+
# sees consistent Q/DQ types with the surrounding compute.
376+
scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32
377+
scale_val = np.array(1.0 / 448.0, dtype=scale_dtype)
378+
scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val)
379+
dq_scale_constant = gs.Constant(
380+
softmax_node.name + "/softmax_dq_scale", scale_val.copy()
381+
)
382+
383+
zp_tensor = onnx.TensorProto()
384+
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
385+
zp_tensor.dims.extend([1])
386+
zp_tensor.raw_data = b"\x00"
387+
zp_constant = gs.Constant(
388+
softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor)
389+
)
390+
391+
q_output = gs.Variable(softmax_node.name + "/q_output")
392+
dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype)
393+
q_node = gs.Node(
394+
op="QuantizeLinear",
395+
name=softmax_node.name + "/QuantizeLinear",
396+
inputs=[softmax_output, scale_constant, zp_constant],
397+
outputs=[q_output],
398+
attrs={"saturate": 1},
399+
)
400+
dq_node = gs.Node(
401+
op="DequantizeLinear",
402+
name=softmax_node.name + "/DequantizeLinear",
403+
inputs=[q_output, dq_scale_constant],
404+
outputs=[dq_output],
405+
)
406+
graph.nodes.extend([q_node, dq_node])
407+
408+
for consumer in list(softmax_output.outputs):
409+
if consumer is q_node:
410+
continue
411+
for i, inp in enumerate(consumer.inputs):
412+
if inp is softmax_output:
413+
consumer.inputs[i] = dq_output
414+
count += 1
415+
416+
graph.cleanup().toposort()
417+
return count
418+
178419
@staticmethod
179420
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
180421
"""Post-processes the ONNX model for FP8 quantization.
181422
182-
Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and
423+
Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear,
183424
adds FP8 weight DQ for Conv layers whose weight quantizers were disabled during
184-
TorchScript export.
425+
TorchScript export, and rewrites attention scaling / K-transpose / softmax-output
426+
patterns so TRT can fuse DQ into the attention MatMul kernels.
185427
186428
Args:
187429
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
@@ -223,5 +465,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
223465
if count > 0:
224466
logger.info(f"Inserted FP8 weight DequantizeLinear for {count} Conv nodes")
225467

468+
# Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls.
469+
n_mul = FP8QuantExporter._move_mul_before_qdq(graph)
470+
n_t = FP8QuantExporter._move_transpose_before_qdq(graph)
471+
n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph)
472+
if n_mul or n_t or n_sm:
473+
logger.info(
474+
f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; "
475+
f"inserted QDQ on {n_sm} Softmax outputs"
476+
)
477+
226478
graph.cleanup().toposort()
227479
return gs.export_onnx(graph)

modelopt/onnx/utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,59 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None:
14151415
consumer.input[i] = input_tensor
14161416

14171417

1418+
_DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
1419+
_Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"}
1420+
1421+
1422+
def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
1423+
"""Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
1424+
import numpy as np
1425+
1426+
if scale_init.data_type != onnx.TensorProto.FLOAT:
1427+
return
1428+
scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
1429+
if not scale_data.size:
1430+
scale_data = np.array(scale_init.float_data, dtype=np.float32)
1431+
scale_init.data_type = onnx.TensorProto.FLOAT16
1432+
scale_init.raw_data = scale_data.astype(np.float16).tobytes()
1433+
del scale_init.float_data[:]
1434+
1435+
1436+
def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
1437+
"""Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.
1438+
1439+
The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly.
1440+
"""
1441+
consumer_map: dict[str, list[onnx.NodeProto]] = {}
1442+
for node in onnx_model.graph.node:
1443+
for inp in node.input:
1444+
consumer_map.setdefault(inp, []).append(node)
1445+
initializers = {init.name: init for init in onnx_model.graph.initializer}
1446+
1447+
to_remove = []
1448+
for node in onnx_model.graph.node:
1449+
if node.op_type != "Cast":
1450+
continue
1451+
cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
1452+
if cast_to != onnx.TensorProto.FLOAT:
1453+
continue
1454+
consumers = consumer_map.get(node.output[0], [])
1455+
if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
1456+
continue
1457+
1458+
for q_node in consumers:
1459+
if len(q_node.input) >= 2 and q_node.input[1] in initializers:
1460+
_scale_fp32_to_fp16(initializers[q_node.input[1]])
1461+
1462+
_bypass_cast_node(onnx_model, node)
1463+
to_remove.append(node)
1464+
1465+
logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
1466+
for node in to_remove:
1467+
onnx_model.graph.node.remove(node)
1468+
return onnx_model
1469+
1470+
14181471
def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
14191472
"""Check if a Constant -> Cast pattern can be folded."""
14201473
assert node.op_type == "Cast"

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
change_casts_to_fp16,
4949
check_model_uses_external_data,
5050
fold_dq_fp32_to_fp16_casts,
51+
fold_q_fp16_to_fp32_casts,
5152
fold_qdq_scale_fp16_to_fp32_casts,
5253
get_input_names,
5354
get_input_shapes,
@@ -663,6 +664,11 @@ def get_onnx_bytes_and_metadata(
663664

664665
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)
665666

667+
# Remove Cast nodes around Q/DQ for optimal TRT fusion
668+
if is_fp8_quantized(model):
669+
onnx_opt_graph = fold_q_fp16_to_fp32_casts(onnx_opt_graph)
670+
onnx_opt_graph = fold_dq_fp32_to_fp16_casts(onnx_opt_graph)
671+
666672
# TensorRT expects all scales to be positive
667673
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
668674

0 commit comments

Comments
 (0)