2828from .fused_moe_backend_base import UnquantizedFusedMoEMethod
2929
3030if current_platform .is_cuda ():
31- from fastdeploy .model_executor .ops .gpu import (
32- count_tokens_per_expert_func ,
33- moe_expert_dispatch ,
34- moe_expert_reduce ,
35- )
31+ from fastdeploy .model_executor .ops .gpu import moe_expert_dispatch , moe_expert_reduce
3632
3733 try :
3834 from fastdeploy .model_executor .ops .gpu import (
4339 logger .warning ("import w4afp8_gemm_scale_permute Failed!" )
4440
4541from fastdeploy .model_executor .layers .moe .moe import get_moe_scores
42+ from fastdeploy .model_executor .layers .quantization .fp8_utils import paddlefleet_ops
4643from fastdeploy .model_executor .utils import (
4744 TensorTracker ,
4845 free_tensor ,
5249)
5350
5451
def deep_batch_gemm(x, y, expert_idx_per_token):
    """Run a grouped (per-expert) BF16 GEMM and return the result tensor.

    Allocates an uninitialized output of shape ``[x.shape[0], y.shape[-1]]``
    with ``x``'s dtype, then fills it in place via
    ``paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous``.

    Args:
        x: Left operand — presumably the permuted token activations,
           rows grouped contiguously by expert (TODO confirm against
           ``moe_permute`` output layout).
        y: Right operand — presumably the stacked per-expert weight
           tensor whose last axis is the output feature dim.
        expert_idx_per_token: Per-row expert index used by the grouped
            GEMM kernel to select which weight slice multiplies each row.

    Returns:
        A new tensor of shape ``[x.shape[0], y.shape[-1]]`` holding the
        grouped-GEMM product.
    """
    num_rows = x.shape[0]
    out_features = y.shape[-1]
    # NOTE(review): paddle.empty leaves the buffer uninitialized; the kernel
    # below is expected to write every row — verify it covers padding rows.
    result = paddle.empty([num_rows, out_features], dtype=x.dtype)
    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, result, expert_idx_per_token)
    return result
57+
5558class CutlassMoEMethod (UnquantizedFusedMoEMethod ):
5659 """
5760 Use Cutlass Group Gemm to compute Fused MoE.
@@ -155,30 +158,22 @@ def apply_ep_prefill(
155158 if fastdeploy .envs .FD_USE_PHI_MOE_PERMUTE and self .moe_quant_type == "w16a16" :
156159 # --- moe_permute / moe_unpermute path ---
157160 recv_topk_idx_i32 = recv_topk_idx .astype (paddle .int32 )
158- (permute_input , permute_indices_per_token , dst_weights , _scale_out ) = paddle .nn .functional .moe_permute (
159- hidden_states = recv_x ,
160- scale = None ,
161- expert_routemap_topk = recv_topk_idx_i32 ,
162- expert_prob_topk = recv_topk_weights ,
163- num_experts = layer .num_local_experts ,
164- tokens_per_expert = [],
165- padding_alignment = 128 ,
166- override_buffer_size = token_all_num ,
161+ (permute_input , permute_indices_per_token , dst_weights , _scale_out , expert_idx_per_token ) = (
162+ paddle .nn .functional .moe_permute (
163+ hidden_states = recv_x ,
164+ scale = None ,
165+ expert_routemap_topk = recv_topk_idx_i32 ,
166+ expert_prob_topk = recv_topk_weights ,
167+ num_experts = layer .num_local_experts ,
168+ tokens_per_expert = [],
169+ padding_alignment = 128 ,
170+ override_buffer_size = token_all_num ,
171+ )
167172 )
168173
169- token_nums_per_expert_cumsum = count_tokens_per_expert_func (
170- recv_topk_idx , layer .num_local_experts , True
171- )[2 ].cast (paddle .int64 )
172- ffn_out = self .compute_ffn (
173- layer ,
174- permute_input ,
175- token_nums_per_expert_cumsum ,
176- None ,
177- False ,
178- - 1 ,
179- None ,
180- None ,
181- )
174+ out = deep_batch_gemm (permute_input , getattr (layer , self .added_weight_attrs [0 ]), expert_idx_per_token )
175+ out = paddlefleet_ops .fused_swiglu_scale (out , dst_weights )
176+ ffn_out = deep_batch_gemm (out , getattr (layer , self .added_weight_attrs [1 ]), expert_idx_per_token )
182177
183178 tmp_ffn_out , _out_probs = paddle .nn .functional .moe_unpermute (
184179 hidden_states_unzipped = ffn_out ,
@@ -187,7 +182,7 @@ def apply_ep_prefill(
187182 token_prob_unzipped = dst_weights ,
188183 total_zipped_tokens = recv_x .shape [0 ],
189184 num_experts = layer .num_local_experts ,
190- using_weighted_combine = True ,
185+ using_weighted_combine = False ,
191186 )
192187 else :
193188 # --- original ep_moe_expert_dispatch / combine path ---
@@ -339,36 +334,33 @@ def apply_tp(
339334 )
340335 topk_idx_i32 = topk_idx .astype (paddle .int32 )
341336 override_buffer_size = x .shape [0 ] * layer .top_k + layer .num_experts * (128 - 1 )
342- (permute_input , permute_indices_per_token , dst_weights , _scale_out ) = ( # zipped_expertwise_rowmap
343- paddle .nn .functional .moe_permute (
344- hidden_states = x ,
345- scale = None ,
346- expert_routemap_topk = topk_idx_i32 ,
347- expert_prob_topk = topk_weights ,
348- num_experts = layer .num_experts ,
349- tokens_per_expert = [],
350- padding_alignment = 128 ,
351- override_buffer_size = override_buffer_size ,
352- )
337+ (
338+ permute_input ,
339+ permute_indices_per_token ,
340+ dst_weights ,
341+ _scale_out ,
342+ expert_idx_per_token ,
343+ ) = paddle .nn .functional .moe_permute ( # zipped_expertwise_rowmap
344+ hidden_states = x ,
345+ scale = None ,
346+ expert_routemap_topk = topk_idx_i32 ,
347+ expert_prob_topk = topk_weights ,
348+ num_experts = layer .num_experts ,
349+ tokens_per_expert = [],
350+ padding_alignment = 128 ,
351+ override_buffer_size = override_buffer_size ,
352+ return_expert_indices = True ,
353353 )
354354
355- # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
356- token_nums_per_expert_cumsum = count_tokens_per_expert_func (topk_idx , layer .num_experts , True )[2 ].cast (
357- paddle .int64
358- )
359355 if topk_ids_hookfunc is not None :
360356 topk_ids_hookfunc (topk_ids = topk_idx )
361357
362- ffn_out = self .compute_ffn (
363- layer ,
364- permute_input ,
365- token_nums_per_expert_cumsum ,
366- None , # expert_idx_per_token not needed for w16a16 without bias
367- False ,
368- - 1 ,
369- None , # dequant_scale
370- None , # max_tokens_per_expert
371- )
358+ out = deep_batch_gemm (permute_input , getattr (layer , self .added_weight_attrs [0 ]), expert_idx_per_token )
359+ out = paddlefleet_ops .fused_swiglu_scale (out , dst_weights )
360+ ffn_out = deep_batch_gemm (out , getattr (layer , self .added_weight_attrs [1 ]), expert_idx_per_token )
361+ if layer .with_bias :
362+ down_proj_bias_expand = paddle .index_select (layer .down_proj_bias , expert_idx_per_token , axis = 0 )
363+ ffn_out = paddle .add (ffn_out , down_proj_bias_expand )
372364
373365 fused_moe_out , _out_probs = paddle .nn .functional .moe_unpermute (
374366 hidden_states_unzipped = ffn_out ,
@@ -377,7 +369,7 @@ def apply_tp(
377369 token_prob_unzipped = dst_weights ,
378370 total_zipped_tokens = x .shape [0 ],
379371 num_experts = layer .num_experts ,
380- using_weighted_combine = True ,
372+ using_weighted_combine = False ,
381373 )
382374 return fused_moe_out
383375
0 commit comments