@@ -92,6 +92,8 @@ def fused_stack_quant(expert_weight_list, transpose=False):
9292 w , scale = _get_fp8_weight_and_scale (expert_weight_list [0 ], stacked = True , transpose = True )
9393 elif transpose is True and hasattr (expert_weight_list [0 ], "fp8_weight_stacked" ):
9494 w , scale = _get_fp8_weight_and_scale (expert_weight_list [0 ], stacked = True , transpose = False )
95+ elif transpose is False and hasattr (expert_weight_list [0 ], "fp8_weight_stacked_transpose" ):
96+ w , scale = _get_fp8_weight_and_scale (expert_weight_list [0 ], stacked = True , transpose = True )
9597 else :
9698 w , scale = paddle .incubate .nn .functional .fused_stack_transpose_quant (expert_weight_list , transpose = transpose )
9799 return w , scale
@@ -114,6 +116,8 @@ def weight_quant(weight, transpose=False):
114116 else :
115117 if hasattr (weight , "fp8_weight" ):
116118 return weight .fp8_weight , weight .fp8_scale
119+ elif hasattr (weight , "fp8_weight_transpose" ):
120+ return weight .fp8_weight_transpose .T .contiguous (), weight .fp8_scale_transpose .T .contiguous ()
117121 else :
118122 return paddle .incubate .nn .functional .fp8_quant_blockwise (
119123 weight ,
@@ -596,23 +600,33 @@ def forward(self, x):
596600 return FP8LinearFunction .apply (x , self , keep_x = False )
597601
598602
599- def cache_fp8_weight (weight , quant_transpose = True ):
600- if hasattr (weight , "fp8_weight" ):
603+ def cache_fp8_weight (weight , quant_transpose = None ):
604+ if hasattr (weight , "fp8_weight" ) or hasattr ( weight , "fp8_weight_transpose" ) :
601605 return
602-
603- if quant_transpose :
606+ if quant_transpose is None :
604607 w_fp8 , w_scale , w_t_fp8 , w_t_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
605608 weight ,
606609 output_scale_transpose = False ,
607610 quant_method = "128x128" ,
608611 input_transpose = True ,
609612 return_transpose_only = False ,
610613 )
614+
611615 setattr (weight , "fp8_weight_transpose" , w_t_fp8 )
612616 setattr (weight , "fp8_scale_transpose" , w_t_scale )
613617 setattr (weight , "fp8_weight" , w_fp8 )
614618 setattr (weight , "fp8_scale" , w_scale )
615- else :
619+ elif quant_transpose is True :
620+ w_t_fp8 , w_t_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
621+ weight ,
622+ output_scale_transpose = False ,
623+ quant_method = "128x128" ,
624+ input_transpose = True ,
625+ return_transpose_only = True ,
626+ )
627+ setattr (weight , "fp8_weight_transpose" , w_t_fp8 )
628+ setattr (weight , "fp8_scale_transpose" , w_t_scale )
629+ elif quant_transpose is False :
616630 w_fp8 , w_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
617631 weight ,
618632 output_scale_transpose = False ,
@@ -622,6 +636,8 @@ def cache_fp8_weight(weight, quant_transpose=True):
622636 )
623637 setattr (weight , "fp8_weight" , w_fp8 )
624638 setattr (weight , "fp8_scale" , w_scale )
639+ else :
640+ raise ValueError ("quant_transpose must be either True, False or None." )
625641
626642
627643class FP8KeepXLinear (paddle .nn .Layer ):
@@ -636,7 +652,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
636652 )
637653 set_parameter_color ([self .weight ], "attn_out_project" )
638654
639- def fp8_quant_weight (self , quant_transpose = True ):
655+ def fp8_quant_weight (self , quant_transpose = None ):
640656 cache_fp8_weight (self .weight , quant_transpose = quant_transpose )
641657
642658 def forward (self , x ):
@@ -798,7 +814,7 @@ def __init__(
798814 is_bias = False ,
799815 )
800816
801- def fp8_quant_weight (self , quant_transpose = True ):
817+ def fp8_quant_weight (self , quant_transpose = None ):
802818 cache_fp8_weight (self .w1 , quant_transpose )
803819 cache_fp8_weight (self .w2 , quant_transpose )
804820
@@ -980,6 +996,10 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
980996 bw_w2_quant = bw_w2_quant .reshape ([len (expert_w2 ), - 1 , bw_w2_quant .shape [- 1 ]])
981997 bw_w2_scale = bw_w2_scale .reshape ([len (expert_w2 ), - 1 , bw_w2_scale .shape [- 1 ]])
982998
999+ if hasattr (expert_w2 [0 ], "fp8_weight_stacked_transpose" ) and not hasattr (expert_w2 [0 ], "fp8_weight_stacked" ):
1000+ bw_w2_quant = bw_w2_quant .contiguous ().transpose ([0 , 2 , 1 ]).contiguous ()
1001+ bw_w2_scale = bw_w2_scale .contiguous ().transpose ([0 , 2 , 1 ]).contiguous ()
1002+
9831003 # compute gemm
9841004 if isinstance (unzipped_grad , tuple ):
9851005 (unzipped_grad_fp8 , unzipped_grad_scale ) = unzipped_grad
@@ -1024,6 +1044,10 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
10241044 bw_w1_quant = bw_w1_quant .reshape ([len (expert_w1 ), - 1 , bw_w1_quant .shape [- 1 ]])
10251045 bw_w1_scale = bw_w1_scale .reshape ([len (expert_w1 ), - 1 , bw_w1_scale .shape [- 1 ]])
10261046
1047+ if hasattr (expert_w1 [0 ], "fp8_weight_stacked_transpose" ) and not hasattr (expert_w1 [0 ], "fp8_weight_stacked" ):
1048+ bw_w1_quant = bw_w1_quant .contiguous ().transpose ([0 , 2 , 1 ]).contiguous ()
1049+ bw_w1_scale = bw_w1_scale .contiguous ().transpose ([0 , 2 , 1 ]).contiguous ()
1050+
10271051 # quant do1
10281052 do1_fp8 , do1_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
10291053 do1 , output_scale_transpose = True , quant_method = "1x128" , input_transpose = False
0 commit comments