@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
6262 # Fold constants is required since the scale is not constant yet.
6363 graph .cleanup ().toposort ().fold_constants ().cleanup ()
6464
65+ n_t_folded = 0
66+
6567 for node in graph .nodes :
6668 if node .op == "TRT_FP8QuantizeLinear" :
6769 # Should not remove input QDQ (only process weight quantization)
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
7880 f"QDQ does not occur in pairs. You reached { dq_op .op } "
7981 )
8082
83+ # Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or
84+ # ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul.
85+ transpose_to_remove = None
86+ cast_to_remove = None
87+ for candidate in list (dq_op .outputs [0 ].outputs ):
88+ if candidate .op == "Cast" :
89+ cast_to_remove = candidate
90+ candidate = next (
91+ (c for c in candidate .outputs [0 ].outputs if c .op == "Transpose" ),
92+ None ,
93+ )
94+ if candidate is None :
95+ cast_to_remove = None
96+ continue
97+ if candidate .op != "Transpose" :
98+ cast_to_remove = None
99+ continue
100+ if any (c .op == "MatMul" for c in candidate .outputs [0 ].outputs ):
101+ perm = candidate .attrs .get ("perm" , None )
102+ torch_weights = (
103+ torch_weights .permute (* perm ).contiguous ()
104+ if perm is not None
105+ else torch_weights .T .contiguous ()
106+ )
107+ transpose_to_remove = candidate
108+ break
109+
81110 # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
82111 numpy_weights = (
83112 (torch_weights / torch_scale ).to (torch .float8_e4m3fn ).view (torch .uint8 ).numpy ()
@@ -94,9 +123,23 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
94123 dq_op .inputs [0 ] = onnx_weights_fp8
95124 dq_op .op = "DequantizeLinear"
96125 dq_op .outputs [0 ].dtype = dq_op .inputs [1 ].dtype
126+ dq_op .outputs [0 ].shape = list (numpy_weights .shape )
127+
128+ if transpose_to_remove is not None :
129+ t_out = transpose_to_remove .outputs [0 ]
130+ for consumer in list (t_out .outputs ):
131+ for i , inp in enumerate (consumer .inputs ):
132+ if inp is t_out :
133+ consumer .inputs [i ] = dq_op .outputs [0 ]
134+ transpose_to_remove .outputs .clear ()
135+ if cast_to_remove is not None :
136+ cast_to_remove .outputs .clear ()
137+ n_t_folded += 1
97138
98139 graph .cleanup ().toposort ()
99140 end_time = time .time ()
141+ if n_t_folded > 0 :
142+ logger .info (f"Folded { n_t_folded } weight Transpose nodes during weight compression" )
100143 print (f"fp8 qdq replaced with only dq completed in { end_time - start_time } s." )
101144
102145 return gs .export_onnx (graph )
@@ -175,13 +218,212 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int:
175218
176219 return count
177220
@staticmethod
def _move_mul_before_qdq(graph: gs.Graph) -> int:
    """Move an attention-scaling ``Mul`` by a scalar constant to before Q/DQ.

    Rewrites ``Q -> DQ [-> Transpose] -> Mul(c) -> MatMul`` into
    ``Mul(c) -> Q -> DQ [-> Transpose] -> MatMul`` so TensorRT sees the
    DequantizeLinear feeding the attention MatMul directly and can fuse it.
    Scaling before Q with an unchanged quantization scale is numerically
    equivalent up to rounding: ``DQ(Q(x * c)) ~= x * c``, just like the
    original ``DQ(Q(x)) * c``.

    Args:
        graph: The graphsurgeon graph; modified in place.

    Returns:
        The number of Mul nodes moved.
    """
    count = 0
    for mul_node in list(graph.nodes):
        if mul_node.op != "Mul":
            continue

        # The pattern requires one scalar-constant operand and one tensor
        # operand produced by exactly one node.
        const_input = next(
            (i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1),
            None,
        )
        tensor_input = next(
            (i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None
        )
        if const_input is None or tensor_input is None:
            continue
        if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1):
            continue
        # The Mul must be the tensor's only consumer: after the move the
        # Q/DQ chain carries pre-scaled values, which would corrupt any
        # other consumer of this tensor.
        if len(tensor_input.outputs) != 1:
            continue

        producer = tensor_input.inputs[0]
        transpose_node = producer if producer.op == "Transpose" else None
        dq_node = producer if producer.op == "DequantizeLinear" else None
        if transpose_node is not None:
            t_input = transpose_node.inputs[0]
            if (
                isinstance(t_input, gs.Variable)
                and len(t_input.inputs) == 1
                # Same single-consumer requirement on the DQ output.
                and len(t_input.outputs) == 1
                and t_input.inputs[0].op == "DequantizeLinear"
            ):
                dq_node = t_input.inputs[0]
        if dq_node is None:
            continue

        q_output = dq_node.inputs[0]
        if (
            not isinstance(q_output, gs.Variable)
            or len(q_output.inputs) != 1
            or q_output.inputs[0].op != "QuantizeLinear"
        ):
            continue
        # The Q output must feed only this DQ, or sibling DQ branches would
        # also see the scaled values.
        if len(q_output.outputs) != 1:
            continue
        q_node = q_output.inputs[0]
        q_input = q_node.inputs[0]
        if not isinstance(q_input, gs.Variable):
            continue

        mul_output = mul_node.outputs[0]
        if not any(c.op == "MatMul" for c in mul_output.outputs):
            continue
        # Do not orphan a graph output by bypassing the Mul.
        if mul_output in graph.outputs:
            continue

        # Insert the scalar Mul on the QuantizeLinear input.
        new_mul_output = gs.Variable(
            q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape
        )
        graph.nodes.append(
            gs.Node(
                op="Mul",
                name=mul_node.name + "_moved",
                inputs=[q_input, const_input],
                outputs=[new_mul_output],
            )
        )
        q_node.inputs[0] = new_mul_output

        # Bypass the original Mul; cleanup() drops the dangling node.
        replacement = (
            transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0]
        )
        for consumer in list(mul_output.outputs):
            for i, inp in enumerate(consumer.inputs):
                if inp is mul_output:
                    consumer.inputs[i] = replacement
        mul_node.outputs.clear()
        count += 1

    graph.cleanup().toposort()
    return count
299+
@staticmethod
def _move_transpose_before_qdq(graph: gs.Graph) -> int:
    """Move Transpose from ``DQ -> Transpose -> MatMul`` to before Q (K path).

    Rewrites ``Q -> DQ -> Transpose -> MatMul`` into
    ``Transpose -> Q -> DQ -> MatMul`` so TensorRT sees DQ feeding the MatMul
    directly and can fuse it. Transpose commutes with per-tensor Q/DQ, so the
    values reaching the MatMul are unchanged.

    Args:
        graph: The graphsurgeon graph; modified in place.

    Returns:
        The number of Transpose nodes moved.
    """
    count = 0
    for transpose_node in list(graph.nodes):
        if transpose_node.op != "Transpose":
            continue

        t_input = transpose_node.inputs[0]
        if (
            not isinstance(t_input, gs.Variable)
            or len(t_input.inputs) != 1
            or t_input.inputs[0].op != "DequantizeLinear"
        ):
            continue
        # After the move the DQ output carries transposed data, so this
        # Transpose must be its only consumer.
        if len(t_input.outputs) != 1:
            continue
        dq_node = t_input.inputs[0]

        dq_input = dq_node.inputs[0]
        if (
            not isinstance(dq_input, gs.Variable)
            or len(dq_input.inputs) != 1
            or dq_input.inputs[0].op != "QuantizeLinear"
        ):
            continue
        # Likewise the Q output must feed only this DQ.
        if len(dq_input.outputs) != 1:
            continue
        q_node = dq_input.inputs[0]
        q_input = q_node.inputs[0]
        if not isinstance(q_input, gs.Variable):
            continue

        # Transpose only commutes with per-tensor quantization. Skip
        # constant non-scalar (per-axis) scales, whose axis attribute
        # would need permuting as well.
        if any(
            len(n.inputs) > 1
            and isinstance(n.inputs[1], gs.Constant)
            and n.inputs[1].values.size > 1
            for n in (q_node, dq_node)
        ):
            continue

        t_output = transpose_node.outputs[0]
        if not any(c.op == "MatMul" for c in t_output.outputs):
            continue
        # Do not orphan a graph output by bypassing the Transpose.
        if t_output in graph.outputs:
            continue

        # Insert a Transpose with the same perm on the QuantizeLinear input.
        new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype)
        graph.nodes.append(
            gs.Node(
                op="Transpose",
                name=transpose_node.name + "_moved",
                inputs=[q_input],
                outputs=[new_t_output],
                attrs=transpose_node.attrs,
            )
        )
        q_node.inputs[0] = new_t_output

        # Rewire consumers to the DQ output (now transposed); cleanup()
        # drops the dangling original Transpose.
        for consumer in list(t_output.outputs):
            for i, inp in enumerate(consumer.inputs):
                if inp is t_output:
                    consumer.inputs[i] = dq_node.outputs[0]
        transpose_node.outputs.clear()
        count += 1

    graph.cleanup().toposort()
    return count
354+
@staticmethod
def _insert_qdq_after_softmax(graph: gs.Graph) -> int:
    """Insert an FP8 Q->DQ pair on every Softmax output that feeds a MatMul.

    Torch export leaves the softmax output unquantized, but TensorRT's MHA
    fusion requires a Q/DQ there. A scale of 1/448 maps 1.0 (the softmax
    maximum) exactly onto the largest FP8 E4M3 value, using the full
    representable range without clipping.

    Args:
        graph: The graphsurgeon graph; modified in place.

    Returns:
        The number of Softmax outputs that received a Q/DQ pair.
    """
    import numpy as np

    inserted = 0
    for node in list(graph.nodes):
        if node.op != "Softmax":
            continue
        out_var = node.outputs[0]
        # Snapshot the consumer list before wiring in the new Q node, which
        # would otherwise register itself as a consumer of the output.
        downstream = list(out_var.outputs)
        # Skip outputs already quantized, and outputs that never reach MatMul.
        if any(c.op == "QuantizeLinear" for c in downstream):
            continue
        if not any(c.op == "MatMul" for c in downstream):
            continue

        # Match the scale dtype to the graph's current float dtype so TRT
        # stronglyTyped sees consistent Q/DQ types with surrounding compute.
        dtype = np.float32 if out_var.dtype is None else out_var.dtype
        q_scale = gs.Constant(
            node.name + "/softmax_q_scale", np.array(1.0 / 448.0, dtype=dtype)
        )
        dq_scale = gs.Constant(
            node.name + "/softmax_dq_scale", np.array(1.0 / 448.0, dtype=dtype)
        )

        # numpy has no FP8 dtype, so build the zero point as a raw ONNX
        # tensor wrapped in LazyValues.
        zp_proto = onnx.TensorProto()
        zp_proto.data_type = onnx.TensorProto.FLOAT8E4M3FN
        zp_proto.dims.extend([1])
        zp_proto.raw_data = b"\x00"
        zero_point = gs.Constant(
            node.name + "/softmax_q_zero_point", LazyValues(zp_proto)
        )

        quantized = gs.Variable(node.name + "/q_output")
        dequantized = gs.Variable(node.name + "/dq_output", dtype=out_var.dtype)
        quantize = gs.Node(
            op="QuantizeLinear",
            name=node.name + "/QuantizeLinear",
            inputs=[out_var, q_scale, zero_point],
            outputs=[quantized],
            attrs={"saturate": 1},
        )
        dequantize = gs.Node(
            op="DequantizeLinear",
            name=node.name + "/DequantizeLinear",
            inputs=[quantized, dq_scale],
            outputs=[dequantized],
        )
        graph.nodes.append(quantize)
        graph.nodes.append(dequantize)

        # Re-point the pre-existing consumers at the dequantized tensor.
        for consumer in downstream:
            for idx, tensor in enumerate(consumer.inputs):
                if tensor is out_var:
                    consumer.inputs[idx] = dequantized
        inserted += 1

    graph.cleanup().toposort()
    return inserted
418+
178419 @staticmethod
179420 def post_process (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
180421 """Post-processes the ONNX model for FP8 quantization.
181422
182- Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and
423+ Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear,
183424 adds FP8 weight DQ for Conv layers whose weight quantizers were disabled during
184- TorchScript export.
425+ TorchScript export, and rewrites attention scaling / K-transpose / softmax-output
426+ patterns so TRT can fuse DQ into the attention MatMul kernels.
185427
186428 Args:
187429 onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
@@ -223,5 +465,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
223465 if count > 0 :
224466 logger .info (f"Inserted FP8 weight DequantizeLinear for { count } Conv nodes" )
225467
468+ # Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls.
469+ n_mul = FP8QuantExporter ._move_mul_before_qdq (graph )
470+ n_t = FP8QuantExporter ._move_transpose_before_qdq (graph )
471+ n_sm = FP8QuantExporter ._insert_qdq_after_softmax (graph )
472+ if n_mul or n_t or n_sm :
473+ logger .info (
474+ f"Attention QDQ rewrites: moved { n_mul } Mul, { n_t } Transpose; "
475+ f"inserted QDQ on { n_sm } Softmax outputs"
476+ )
477+
226478 graph .cleanup ().toposort ()
227479 return gs .export_onnx (graph )
0 commit comments