1313aten = torch .ops .aten
1414from tensorrt_llm .mapping import Mapping
1515
16+ from . import MATCHER_SUBSYSTEM
17+
18+
def _append_named_pass(custom_passes: List[PatternMatcherPass], pass_name: str):
    """Create a named ``PatternMatcherPass`` and append it to ``custom_passes``.

    The pass is built with ``pass_name`` and ``MATCHER_SUBSYSTEM`` as its two
    positional arguments. ``custom_passes`` is mutated in place; nothing is
    returned, so callers reach the new pass through ``custom_passes[-1]``.
    """
    named_pass = PatternMatcherPass(pass_name, MATCHER_SUBSYSTEM)
    custom_passes.append(named_pass)
21+
22+
23+ def _check_getitem_only_users (match : Match , pattern_node ) -> bool :
24+ node = match .ctx .pattern_to_node [pattern_node ]
25+ if not isinstance (node , torch .fx .graph .Node ):
26+ return False
27+ for user in node .users :
28+ if user .op != "call_function" or user .target is not getitem :
29+ return False
30+ return True
31+
32+
33+ def _has_getitem_user (match : Match , pattern_node , index : int ) -> bool :
34+ node = match .ctx .pattern_to_node [pattern_node ]
35+ if not isinstance (node , torch .fx .graph .Node ):
36+ return False
37+ for user in node .users :
38+ if (user .op == "call_function" and user .target is getitem
39+ and user .args [1 ] == index ):
40+ return True
41+ return False
42+
43+
def _make_fp8_quant_extra_check(input_node, strategy_node, quant_node,
                                require_scale_output: bool):
    """Build the ``extra_check`` predicate used when registering a replacement.

    Args:
        input_node: Pattern node forwarded to ``check_f16_bf16_input``.
        strategy_node: Pattern node forwarded to ``check_non_ub_strategy``.
        quant_node: Pattern node whose users are inspected.
        require_scale_output: Whether the match must (True) or must not
            (False) have a ``getitem`` user selecting index 1 of ``quant_node``.

    Returns:
        A callable taking a ``Match``. It accepts the match only if both
        input/strategy checks pass, ``quant_node`` is consumed exclusively by
        ``getitem`` calls, and the presence of an index-1 ``getitem`` user
        equals ``require_scale_output``.
    """

    def extra_check(match: Match) -> bool:
        # Evaluate the two external checks with the same short-circuit
        # semantics as a plain ``and`` chain, returning the falsy value as-is.
        inputs_ok = (check_f16_bf16_input(match, input_node)
                     and check_non_ub_strategy(match, strategy_node))
        if not inputs_ok:
            return inputs_ok
        if not _check_getitem_only_users(match, quant_node):
            return False
        scale_is_consumed = _has_getitem_user(match, quant_node, 1)
        return scale_is_consumed == require_scale_output

    return extra_check
54+
1655
1756def register_ar_residual_norm (custom_pass : PatternMatcherPass , mapping : Mapping ,
1857 allreduce_func : Callable ):
@@ -134,15 +173,16 @@ def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
134173 torch .ops .tensorrt_llm .static_quantize_e4m3_per_tensor .default ,
135174 getitem_0 ,
136175 KeywordArg ("scale" ),
137- _users = 2 )
138- getitem_2 = CallFunction (getitem ,
139- static_quantize_e4m3_per_tensor_default ,
140- 0 ,
141- _users = 2 )
176+ _users = MULTIPLE )
177+ getitem_2 = CallFunction (getitem , static_quantize_e4m3_per_tensor_default ,
178+ 0 )
142179 getitem_3 = CallFunction (getitem , static_quantize_e4m3_per_tensor_default ,
143180 1 )
144- pattern = MultiOutputPattern ([getitem_0 , getitem_1 , getitem_2 , getitem_3
145- ]) # norm_out, residual_out, quant_out, scale
181+ pattern_with_scale = MultiOutputPattern (
182+ [getitem_0 , getitem_1 , getitem_2 ,
183+ getitem_3 ]) # norm_out, residual_out, quant_out, scale
184+ pattern_without_scale = MultiOutputPattern (
185+ [getitem_0 , getitem_1 , getitem_2 ]) # norm_out, residual_out, quant_out
146186
147187 def empty_pattern (
148188 input : torch .Tensor ,
@@ -173,18 +213,48 @@ def target_pattern(
173213 trigger_completion_at_end )
174214 return allreduce [0 ], allreduce [2 ], allreduce [1 ], scale
175215
176- def extra_check (match : Match ) -> bool :
177- return check_f16_bf16_input (
178- match , input_node ) and check_non_ub_strategy (match , strategy_node )
216+ def target_pattern_without_scale (
217+ input : torch .Tensor ,
218+ residual : torch .Tensor ,
219+ gamma : torch .Tensor ,
220+ workspace : torch .LongTensor ,
221+ strategy : int ,
222+ eps : float ,
223+ scale : torch .Tensor ,
224+ trigger_completion_at_end : bool ,
225+ ):
226+ allreduce = allreduce_func (
227+ input , residual , gamma , scale , None , workspace , mapping .tp_group ,
228+ int (strategy ),
229+ int (AllReduceFusionOp .RESIDUAL_RMS_NORM_OUT_QUANT_FP8 ), float (eps ),
230+ trigger_completion_at_end )
231+ return allreduce [0 ], allreduce [2 ], allreduce [1 ]
232+
233+ extra_check_with_scale = _make_fp8_quant_extra_check (
234+ input_node , strategy_node , static_quantize_e4m3_per_tensor_default ,
235+ True )
236+ extra_check_without_scale = _make_fp8_quant_extra_check (
237+ input_node , strategy_node , static_quantize_e4m3_per_tensor_default ,
238+ False )
179239
180240 register_replacement (
181241 empty_pattern ,
182242 target_pattern ,
183243 [],
184244 fwd_only ,
185245 custom_pass ,
186- search_fn_pattern = pattern ,
187- extra_check = extra_check ,
246+ search_fn_pattern = pattern_with_scale ,
247+ extra_check = extra_check_with_scale ,
248+ )
249+
250+ register_replacement (
251+ empty_pattern ,
252+ target_pattern_without_scale ,
253+ [],
254+ fwd_only ,
255+ custom_pass ,
256+ search_fn_pattern = pattern_without_scale ,
257+ extra_check = extra_check_without_scale ,
188258 )
189259
190260
@@ -212,15 +282,15 @@ def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
212282 torch .ops .tensorrt_llm .static_quantize_e4m3_per_tensor .default ,
213283 getitem_0 ,
214284 KeywordArg ("scale" ),
215- _users = 2 )
216- getitem_2 = CallFunction (getitem ,
217- static_quantize_e4m3_per_tensor_default ,
218- 0 ,
219- _users = 2 )
285+ _users = MULTIPLE )
286+ getitem_2 = CallFunction (getitem , static_quantize_e4m3_per_tensor_default ,
287+ 0 )
220288 getitem_3 = CallFunction (getitem , static_quantize_e4m3_per_tensor_default ,
221289 1 )
222- pattern = MultiOutputPattern ([getitem_1 , getitem_2 ,
223- getitem_3 ]) # residual_out, quant_out, scale
290+ pattern_with_scale = MultiOutputPattern (
291+ [getitem_1 , getitem_2 , getitem_3 ]) # residual_out, quant_out, scale
292+ pattern_without_scale = MultiOutputPattern ([getitem_1 , getitem_2
293+ ]) # residual_out, quant_out
224294
225295 def empty_pattern (
226296 input : torch .Tensor ,
@@ -250,18 +320,47 @@ def target_pattern(
250320 float (eps ), trigger_completion_at_end )
251321 return allreduce [1 ], allreduce [0 ], scale
252322
253- def extra_check (match : Match ) -> bool :
254- return check_f16_bf16_input (
255- match , input_node ) and check_non_ub_strategy (match , strategy_node )
323+ def target_pattern_without_scale (
324+ input : torch .Tensor ,
325+ residual : torch .Tensor ,
326+ gamma : torch .Tensor ,
327+ workspace : torch .LongTensor ,
328+ strategy : int ,
329+ eps : float ,
330+ scale : torch .Tensor ,
331+ trigger_completion_at_end : bool ,
332+ ):
333+ allreduce = allreduce_func (
334+ input , residual , gamma , scale , None , workspace , mapping .tp_group ,
335+ int (strategy ), int (AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_FP8 ),
336+ float (eps ), trigger_completion_at_end )
337+ return allreduce [1 ], allreduce [0 ]
338+
339+ extra_check_with_scale = _make_fp8_quant_extra_check (
340+ input_node , strategy_node , static_quantize_e4m3_per_tensor_default ,
341+ True )
342+ extra_check_without_scale = _make_fp8_quant_extra_check (
343+ input_node , strategy_node , static_quantize_e4m3_per_tensor_default ,
344+ False )
256345
257346 register_replacement (
258347 empty_pattern ,
259348 target_pattern ,
260349 [],
261350 fwd_only ,
262351 custom_pass ,
263- search_fn_pattern = pattern ,
264- extra_check = extra_check ,
352+ search_fn_pattern = pattern_with_scale ,
353+ extra_check = extra_check_with_scale ,
354+ )
355+
356+ register_replacement (
357+ empty_pattern ,
358+ target_pattern_without_scale ,
359+ [],
360+ fwd_only ,
361+ custom_pass ,
362+ search_fn_pattern = pattern_without_scale ,
363+ extra_check = extra_check_without_scale ,
265364 )
266365
267366
@@ -772,16 +871,20 @@ def extra_check(match: Match) -> bool:
772871 extra_check = extra_check ,
773872 )
774873
775- custom_passes .append (PatternMatcherPass ())
874+ _append_named_pass (
875+ custom_passes ,
876+ f"ub_convert_supported_ar_to_ub:{ allreduce_func .__name__ } " )
776877 register_convert_supported_ar_to_ub (custom_passes [- 1 ])
777878
778- custom_passes . append ( PatternMatcherPass () )
879+ _append_named_pass ( custom_passes , f"ub_prologue: { allreduce_func . __name__ } " )
779880 register_ub_prologue_patterns (custom_passes [- 1 ])
780881
781- custom_passes . append ( PatternMatcherPass () )
882+ _append_named_pass ( custom_passes , f"ub_finalize: { allreduce_func . __name__ } " )
782883 register_ub_finalize_patterns (custom_passes [- 1 ])
783884
784- custom_passes .append (PatternMatcherPass ())
885+ _append_named_pass (
886+ custom_passes ,
887+ f"insert_copy_for_graph_output:{ allreduce_func .__name__ } " )
785888 insert_copy_for_graph_output (custom_passes [- 1 ])
786889
787890
@@ -792,7 +895,7 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass],
792895 register_ar_residual_norm (custom_passes [- 1 ], mapping ,
793896 torch .ops .trtllm .tunable_allreduce )
794897
795- custom_passes . append ( PatternMatcherPass () )
898+ _append_named_pass ( custom_passes , "ar_residual_norm_quant" )
796899 for allreduce_func in [
797900 torch .ops .trtllm .allreduce , torch .ops .trtllm .tunable_allreduce
798901 ]:
0 commit comments