@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         # Folding constants is required since the scale is not constant yet.
         graph.cleanup().toposort().fold_constants().cleanup()
 
+        n_t_folded = 0
+
         for node in graph.nodes:
             if node.op == "TRT_FP8QuantizeLinear":
                 # Should not remove input QDQ (only process weight quantization)
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                         f"QDQ does not occur in pairs. You reached {dq_op.op}"
                     )
 
+                # Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or
+                # ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ → MatMul.
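+                # Rewrite pattern:
+                #   Before: DQ → (Cast) → Transpose → MatMul
+                #   After:  DQ(pre-transposed weight) → MatMul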
+                transpose_to_remove = None
+                cast_to_remove = None
+                for candidate in list(dq_op.outputs[0].outputs):
+                    if candidate.op == "Cast":
+                        cast_to_remove = candidate
+                        candidate = next(
+                            (c for c in candidate.outputs[0].outputs if c.op == "Transpose"),
+                            None,
+                        )
+                    if candidate is None:
+                        cast_to_remove = None
+                        continue
+                    if candidate.op != "Transpose":
+                        cast_to_remove = None
+                        continue
+                    if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
+                        perm = candidate.attrs.get("perm", None)
+                        torch_weights = (
+                            torch_weights.permute(*perm).contiguous()
+                            if perm is not None
+                            else torch_weights.T.contiguous()
+                        )
+                        transpose_to_remove = candidate
+                        break
+
                 # Replace it with a DequantizeLinear carrying FP8 weights. This is a workaround (WAR) because numpy does not support fp8.
                 numpy_weights = (
                     (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
@@ -94,20 +123,232 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                 dq_op.inputs[0] = onnx_weights_fp8
                 dq_op.op = "DequantizeLinear"
                 dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
+                dq_op.outputs[0].shape = list(numpy_weights.shape)
+
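+                # If a weight Transpose was folded above, rewire its consumers (and any
+                # intermediate Cast) straight to the DQ output so cleanup() drops the dead nodes.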
+                if transpose_to_remove is not None:
+                    t_out = transpose_to_remove.outputs[0]
+                    for consumer in list(t_out.outputs):
+                        for i, inp in enumerate(consumer.inputs):
+                            if inp is t_out:
+                                consumer.inputs[i] = dq_op.outputs[0]
+                    transpose_to_remove.outputs.clear()
+                    if cast_to_remove is not None:
+                        cast_to_remove.outputs.clear()
+                    n_t_folded += 1
 
         graph.cleanup().toposort()
         end_time = time.time()
+        if n_t_folded > 0:
+            logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression")
         print(f"FP8 QDQ to DQ-only replacement completed in {end_time - start_time}s.")
 
         return gs.export_onnx(graph)
 
+    @staticmethod
+    def _move_mul_before_qdq(graph: gs.Graph) -> int:
+        """Move the attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion.
+
+        Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path).
+        """
+        count = 0
+        for mul_node in list(graph.nodes):
+            if mul_node.op != "Mul":
+                continue
+
+            const_input = next(
+                (i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1),
+                None,
+            )
+            tensor_input = next(
+                (i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None
+            )
+            if const_input is None or tensor_input is None:
+                continue
+            if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1):
+                continue
+
+            producer = tensor_input.inputs[0]
+            transpose_node = producer if producer.op == "Transpose" else None
+            dq_node = producer if producer.op == "DequantizeLinear" else None
+            if transpose_node is not None:
+                t_input = transpose_node.inputs[0]
+                if (
+                    isinstance(t_input, gs.Variable)
+                    and len(t_input.inputs) == 1
+                    and t_input.inputs[0].op == "DequantizeLinear"
+                ):
+                    dq_node = t_input.inputs[0]
+            if dq_node is None:
+                continue
+
+            q_output = dq_node.inputs[0]
+            if (
+                not isinstance(q_output, gs.Variable)
+                or len(q_output.inputs) != 1
+                or q_output.inputs[0].op != "QuantizeLinear"
+            ):
+                continue
+            q_node = q_output.inputs[0]
+            q_input = q_node.inputs[0]
+            if not isinstance(q_input, gs.Variable):
+                continue
+
+            mul_output = mul_node.outputs[0]
+            if not any(c.op == "MatMul" for c in mul_output.outputs):
+                continue
+
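+            # Re-create the Mul on the high-precision Q input (Q/DQ scales are left
+            # untouched) and route the old Mul's consumers around it.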
+            new_mul_output = gs.Variable(
+                q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape
+            )
+            graph.nodes.append(
+                gs.Node(
+                    op="Mul",
+                    name=mul_node.name + "_moved",
+                    inputs=[q_input, const_input],
+                    outputs=[new_mul_output],
+                )
+            )
+            q_node.inputs[0] = new_mul_output
+
+            replacement = (
+                transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0]
+            )
+            for consumer in list(mul_output.outputs):
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is mul_output:
+                        consumer.inputs[i] = replacement
+            mul_node.outputs.clear()
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
+    @staticmethod
+    def _move_transpose_before_qdq(graph: gs.Graph) -> int:
+        """Move the Transpose in ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path)."""
+        count = 0
+        for transpose_node in list(graph.nodes):
+            if transpose_node.op != "Transpose":
+                continue
+
+            t_input = transpose_node.inputs[0]
+            if (
+                not isinstance(t_input, gs.Variable)
+                or len(t_input.inputs) != 1
+                or t_input.inputs[0].op != "DequantizeLinear"
+            ):
+                continue
+            dq_node = t_input.inputs[0]
+
+            dq_input = dq_node.inputs[0]
+            if (
+                not isinstance(dq_input, gs.Variable)
+                or len(dq_input.inputs) != 1
+                or dq_input.inputs[0].op != "QuantizeLinear"
+            ):
+                continue
+            q_node = dq_input.inputs[0]
+            q_input = q_node.inputs[0]
+            if not isinstance(q_input, gs.Variable):
+                continue
+
+            t_output = transpose_node.outputs[0]
+            if not any(c.op == "MatMul" for c in t_output.outputs):
+                continue
+
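+            # Re-create the Transpose (same perm attrs) on the Q input and let the
+            # MatMul consume the DQ output directly.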
+            new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype)
+            graph.nodes.append(
+                gs.Node(
+                    op="Transpose",
+                    name=transpose_node.name + "_moved",
+                    inputs=[q_input],
+                    outputs=[new_t_output],
+                    attrs=transpose_node.attrs,
+                )
+            )
+            q_node.inputs[0] = new_t_output
+
+            for consumer in list(t_output.outputs):
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is t_output:
+                        consumer.inputs[i] = dq_node.outputs[0]
+            transpose_node.outputs.clear()
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
+    @staticmethod
+    def _insert_qdq_after_softmax(graph: gs.Graph) -> int:
+        """Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion).
+
+        Torch export does not quantize the softmax output; scale = 1/448 saturates exactly
+        at 1.0 (softmax range is [0, 1]) while covering the full FP8 E4M3 representable range.
+        """
+        import numpy as np
+
+        count = 0
+        for softmax_node in list(graph.nodes):
+            if softmax_node.op != "Softmax":
+                continue
+            softmax_output = softmax_node.outputs[0]
+            if not any(c.op == "MatMul" for c in softmax_output.outputs):
+                continue
+            if any(c.op == "QuantizeLinear" for c in softmax_output.outputs):
+                continue
+
+            # Match the scale dtype to the graph's current float dtype so TRT's
+            # stronglyTyped mode sees Q/DQ types consistent with the surrounding compute.
+            scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32
+            scale_val = np.array(1.0 / 448.0, dtype=scale_dtype)
+            scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val)
+            dq_scale_constant = gs.Constant(
+                softmax_node.name + "/softmax_dq_scale", scale_val.copy()
+            )
+
+            zp_tensor = onnx.TensorProto()
+            zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
+            zp_tensor.dims.extend([1])
+            zp_tensor.raw_data = b"\x00"
+            zp_constant = gs.Constant(
+                softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor)
+            )
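+            # numpy cannot represent fp8, so the zero point is built as a raw
+            # FLOAT8E4M3FN TensorProto (a single 0x00 byte) wrapped in graphsurgeon's
+            # LazyValues so it round-trips through export untouched.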
+
+            q_output = gs.Variable(softmax_node.name + "/q_output")
+            dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype)
+            q_node = gs.Node(
+                op="QuantizeLinear",
+                name=softmax_node.name + "/QuantizeLinear",
+                inputs=[softmax_output, scale_constant, zp_constant],
+                outputs=[q_output],
+                attrs={"saturate": 1},
+            )
+            dq_node = gs.Node(
+                op="DequantizeLinear",
+                name=softmax_node.name + "/DequantizeLinear",
+                inputs=[q_output, dq_scale_constant],
+                outputs=[dq_output],
+            )
+            graph.nodes.extend([q_node, dq_node])
+
+            for consumer in list(softmax_output.outputs):
+                if consumer is q_node:
+                    continue
+                for i, inp in enumerate(consumer.inputs):
+                    if inp is softmax_output:
+                        consumer.inputs[i] = dq_output
+            count += 1
+
+        graph.cleanup().toposort()
+        return count
+
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Post-processes the ONNX model for FP8 quantization.
 
-        Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
-        - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
-        - TRT_FP8DequantizeLinear -> DequantizeLinear
+        Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and
+        rewrites the attention-scaling, K-transpose, and softmax-output patterns so TRT
+        can fuse DQ into the attention MatMul kernels.
 
         Args:
             onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
@@ -144,5 +385,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                     f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
                 )
 
+        # Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls.
+        n_mul = FP8QuantExporter._move_mul_before_qdq(graph)
+        n_t = FP8QuantExporter._move_transpose_before_qdq(graph)
+        n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph)
+        if n_mul or n_t or n_sm:
+            logger.info(
+                f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; "
+                f"inserted QDQ on {n_sm} Softmax outputs"
+            )
+
         graph.cleanup().toposort()
         return gs.export_onnx(graph)