@@ -583,11 +583,13 @@ def extra_check_convert_supported_ar_to_ub(match: Match) -> bool:
583583 def register_ub_prologue_patterns (custom_pass : PatternMatcherPass ):
584584
585585 def register_scaled_mm_prologue (custom_pass : PatternMatcherPass ):
586+ output_buffer_kind_key = KeywordArg ('output_buffer_kind' )
586587 trtllm_cublas_scaled_mm_default = CallFunction (
587588 torch .ops .trtllm .cublas_scaled_mm .default , KeywordArg ('mm0_a' ),
588589 KeywordArg ('mm0_b' ), KeywordArg ('mm0_a_scale' ),
589590 KeywordArg ('mm0_b_scale' ), KeywordArg ('mm0_bias' ),
590- KeywordArg ('mm_dtype' ))
591+ KeywordArg ('mm_dtype' ), output_buffer_kind_key ,
592+ mapping .tp_group )
591593 ub_copy = CallFunction (torch .ops .trtllm .copy_to_userbuffers ,
592594 trtllm_cublas_scaled_mm_default )
593595
@@ -598,6 +600,7 @@ def empty_scaled_mm_prologue_pattern(
598600 mm0_b_scale : torch .Tensor ,
599601 mm0_bias : Optional [torch .Tensor ],
600602 mm_dtype : torch .dtype ,
603+ output_buffer_kind : int ,
601604 ):
602605 return
603606
@@ -608,10 +611,11 @@ def target_scaled_mm_prologue_pattern(
608611 mm0_b_scale : torch .Tensor ,
609612 mm0_bias : Optional [torch .Tensor ],
610613 mm_dtype : torch .dtype ,
614+ output_buffer_kind : int ,
611615 ):
612616 scaled_mm_output = torch .ops .trtllm .cublas_scaled_mm (
613617 mm0_a , mm0_b , mm0_a_scale , mm0_b_scale , mm0_bias , mm_dtype ,
614- True )
618+ int ( BufferKind . USERBUFFERS ), mapping . tp_group )
615619 return scaled_mm_output
616620
617621 # No extra check needed as the output dtype of scaled_mm has been verified when
@@ -635,15 +639,9 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
635639 output_buffer_kind_key = KeywordArg ('output_buffer_kind' )
636640 allowed_backends_key = KeywordArg ('allowed_backends' )
637641 trtllm_nvfp4_gemm_default = CallFunction (
638- torch .ops .trtllm .nvfp4_gemm .default ,
639- act_fp4_key ,
640- weight_key ,
641- act_sf_key ,
642- weight_scale_key ,
643- alpha_key ,
644- output_dtype_key ,
645- output_buffer_kind = output_buffer_kind_key ,
646- allowed_backends = allowed_backends_key )
642+ torch .ops .trtllm .nvfp4_gemm .default , act_fp4_key , weight_key ,
643+ act_sf_key , weight_scale_key , alpha_key , output_dtype_key ,
644+ output_buffer_kind_key , allowed_backends_key , mapping .tp_group )
647645 ub_copy = CallFunction (torch .ops .trtllm .copy_to_userbuffers ,
648646 trtllm_nvfp4_gemm_default )
649647
@@ -671,7 +669,48 @@ def target_nvfp4_gemm_prologue_pattern(
671669 ):
672670 nvfp4_gemm_output = torch .ops .trtllm .nvfp4_gemm (
673671 act_fp4 , weight , act_sf , weight_scale , alpha , output_dtype ,
674- int (BufferKind .USERBUFFERS ), allowed_backends )
672+ int (BufferKind .USERBUFFERS ), allowed_backends ,
673+ mapping .tp_group )
674+ return nvfp4_gemm_output
675+
676+ bias_key = KeywordArg ('bias' )
677+ trtllm_nvfp4_gemm_with_bias_default = CallFunction (
678+ torch .ops .trtllm .nvfp4_gemm .default , act_fp4_key , weight_key ,
679+ act_sf_key , weight_scale_key , alpha_key , output_dtype_key ,
680+ output_buffer_kind_key , allowed_backends_key , mapping .tp_group ,
681+ bias_key )
682+ ub_copy_with_bias = CallFunction (
683+ torch .ops .trtllm .copy_to_userbuffers ,
684+ trtllm_nvfp4_gemm_with_bias_default )
685+
686+ def empty_nvfp4_gemm_bias_prologue_pattern (
687+ act_fp4 : torch .Tensor ,
688+ weight : torch .Tensor ,
689+ act_sf : torch .Tensor ,
690+ weight_scale : torch .Tensor ,
691+ alpha : torch .Tensor ,
692+ output_dtype : torch .dtype ,
693+ output_buffer_kind : int ,
694+ allowed_backends : str ,
695+ bias : Optional [torch .Tensor ],
696+ ):
697+ return
698+
699+ def target_nvfp4_gemm_bias_prologue_pattern (
700+ act_fp4 : torch .Tensor ,
701+ weight : torch .Tensor ,
702+ act_sf : torch .Tensor ,
703+ weight_scale : torch .Tensor ,
704+ alpha : torch .Tensor ,
705+ output_dtype : torch .dtype ,
706+ output_buffer_kind : int ,
707+ allowed_backends : str ,
708+ bias : Optional [torch .Tensor ],
709+ ):
710+ nvfp4_gemm_output = torch .ops .trtllm .nvfp4_gemm (
711+ act_fp4 , weight , act_sf , weight_scale , alpha , output_dtype ,
712+ int (BufferKind .USERBUFFERS ), allowed_backends ,
713+ mapping .tp_group , bias )
675714 return nvfp4_gemm_output
676715
677716 def extra_check (match : Match ) -> bool :
@@ -702,6 +741,15 @@ def extra_check(match: Match) -> bool:
702741 search_fn_pattern = ub_copy ,
703742 extra_check = extra_check ,
704743 )
744+ register_replacement (
745+ empty_nvfp4_gemm_bias_prologue_pattern ,
746+ target_nvfp4_gemm_bias_prologue_pattern ,
747+ [],
748+ fwd_only ,
749+ custom_pass ,
750+ search_fn_pattern = ub_copy_with_bias ,
751+ extra_check = extra_check ,
752+ )
705753
706754 def register_mm_prologue (custom_pass : PatternMatcherPass ):
707755 aten_mm_default = CallFunction (aten .mm .default , KeywordArg ('mm0_a' ),
0 commit comments