 from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
+    from fastdeploy.model_executor.ops.gpu import (
+        count_tokens_per_expert_func,
+        moe_expert_dispatch,
+        moe_expert_reduce,
+    )

     try:
         from fastdeploy.model_executor.ops.gpu import (
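Note on the new import: count_tokens_per_expert_func is only consumed through output index [2] further down, in both apply_ep_prefill and apply_tp, where the code's own comment calls it the prefix sum of token_nums_per_expert. A conceptual NumPy sketch of that one output, assuming the op counts routed tokens per expert and returns their cumulative sum at that index; the op's other outputs are not used by this change:

import numpy as np

# Conceptual stand-in for the only output the new code reads: per-expert token
# counts over the flattened top-k routing map, followed by their prefix sum.
def tokens_per_expert_cumsum(topk_idx: np.ndarray, num_experts: int) -> np.ndarray:
    counts = np.bincount(topk_idx.reshape(-1), minlength=num_experts)
    return np.cumsum(counts)

topk_idx = np.array([[0, 2], [1, 2], [2, 3]])  # 3 tokens routed to top-2 of 4 experts
print(tokens_per_expert_cumsum(topk_idx, 4))   # [1 2 5 6]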
@@ -145,14 +149,15 @@ def apply_ep_prefill(
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
+        dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
             recv_num_tokens_per_expert_list,
             handle,
             event,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
+        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)

         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
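Why dispatch gets expert_alignment=128 only when FD_USE_PHI_MOE_PERMUTE is set: the moe_permute call added below pads each expert's block of tokens to a multiple of 128 (padding_alignment=128), so the EP dispatch is asked to align its per-expert receive counts the same way. A minimal sketch of that rounding, under the assumption that the alignment simply rounds each per-expert token count up to the next multiple of 128:

# Illustrative only -- not the DeepEP/FastDeploy implementation.
def align_up(count: int, alignment: int = 128) -> int:
    return ((count + alignment - 1) // alignment) * alignment

print([align_up(c) for c in (5, 130, 257)])  # [128, 256, 384]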
@@ -165,54 +170,91 @@ def apply_ep_prefill(
         # 3. Compute ffn
         if token_all_num > 0:
             logger.debug(f"token_all_num {token_all_num}")
-            (
-                permute_input,
-                permute_indices_per_token,
-                recv_num_tokens_per_expert_list_cumsum,
-                dst_weights,
-                dst_indices,
-                cumsum_idx_gpu,
-                expert_idx_per_token,
-                dequant_scale,
-            ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
-                recv_x,
-                recv_topk_idx,
-                recv_topk_weights,
-                (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
-                recv_num_tokens_per_expert_list,
-                token_all_num,
-                self.moe_quant_type,
-            )
-            if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
-                # only w4a8 and w4afp8 need expert_idx_per_token
-                # Other need not this tensor, so we make it None.
-                expert_idx_per_token = None
+
+            if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+                # --- moe_permute / moe_unpermute path ---
+                recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
+                (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                )
+
+                token_nums_per_expert_cumsum = count_tokens_per_expert_func(
+                    recv_topk_idx, layer.num_local_experts, True
+                )[2].cast(paddle.int64)
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    token_nums_per_expert_cumsum,
+                    None,
+                    False,
+                    -1,
+                    None,
+                    None,
+                )
+
+                tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                    hidden_states_unzipped=ffn_out,
+                    zipped_expertwise_rowmap=permute_indices_per_token,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    token_prob_unzipped=dst_weights,
+                    total_zipped_tokens=recv_x.shape[0],
+                    num_experts=layer.num_local_experts,
+                    using_weighted_combine=True,
+                )
             else:
-                expert_idx_per_token = expert_idx_per_token.cast("int64")
+                # --- original ep_moe_expert_dispatch / combine path ---
+                (
+                    permute_input,
+                    permute_indices_per_token,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    dst_weights,
+                    dst_indices,
+                    cumsum_idx_gpu,
+                    expert_idx_per_token,
+                    dequant_scale,
+                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
+                    recv_x,
+                    recv_topk_idx,
+                    recv_topk_weights,
+                    (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
+                    recv_num_tokens_per_expert_list,
+                    token_all_num,
+                    self.moe_quant_type,
+                )
+                if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
+                    expert_idx_per_token = None
+                else:
+                    expert_idx_per_token = expert_idx_per_token.cast("int64")

-            if hasattr(layer, "up_gate_proj_in_scale"):
-                dequant_scale = None
+                if hasattr(layer, "up_gate_proj_in_scale"):
+                    dequant_scale = None

-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                recv_num_tokens_per_expert_list_cumsum,
-                expert_idx_per_token,
-                False,
-                -1,
-                dequant_scale,
-            )
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    expert_idx_per_token,
+                    False,
+                    -1,
+                    dequant_scale,
+                )

-            # prmt back per rank
-            tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
-                ffn_out,
-                dst_weights,
-                permute_indices_per_token,
-                dst_indices,
-                None,  # down_proj_bias
-                False,  # norm_topk_prob
-                1.0,
-            )
+                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
+                    ffn_out,
+                    dst_weights,
+                    permute_indices_per_token,
+                    dst_indices,
+                    None,  # down_proj_bias
+                    False,  # norm_topk_prob
+                    1.0,
+                )
         else:
             tmp_ffn_out = recv_x

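The moe_permute / moe_unpermute pair above takes over the role of ep_moe_expert_dispatch / ep_moe_expert_combine on the unquantized (w16a16) path: rows are regrouped by expert, run through the expert FFN, then scattered back to their tokens and combined with the routing weights (using_weighted_combine=True). A self-contained NumPy sketch of that round trip, ignoring the 128-row padding and all quantization details; the names and shapes are illustrative, not the Paddle API contract:

import numpy as np

def permute(x, topk_idx):
    # x: [num_tokens, hidden], topk_idx: [num_tokens, top_k]
    # Each (token, expert-slot) pair becomes one row, grouped by expert id.
    order = np.argsort(topk_idx.reshape(-1), kind="stable")
    return x[order // topk_idx.shape[1]], order

def unpermute(ffn_out, order, topk_weights):
    # Weighted combine: each token's expert outputs are summed with its routing probs.
    num_tokens, top_k = topk_weights.shape
    out = np.zeros((num_tokens, ffn_out.shape[1]), dtype=ffn_out.dtype)
    for row, flat in enumerate(order):
        token, slot = divmod(flat, top_k)
        out[token] += topk_weights[token, slot] * ffn_out[row]
    return out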
@@ -292,6 +334,69 @@ def apply_tp(
         Paddle Cutlass compute Fused MoE.
         """
         gate_out = gate(x.cast("float32"))
+        if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+            if layer.topk_method == "noaux_tc":
+                gate_out, topk_weights, topk_idx = get_moe_scores(
+                    gate_out,
+                    layer.n_group,
+                    layer.topk_group,
+                    layer.top_k,
+                    layer.routed_scaling_factor,
+                    layer.gate_correction_bias,
+                    getattr(layer, "renormalize", True),
+                )
+            else:
+                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                    gate_out,
+                    layer.gate_correction_bias,
+                    layer.top_k,
+                    True,  # apply_norm_weight
+                    False,
+                )
+            topk_idx_i32 = topk_idx.astype(paddle.int32)
+            override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
+                paddle.nn.functional.moe_permute(
+                    hidden_states=x,
+                    scale=None,
+                    expert_routemap_topk=topk_idx_i32,
+                    expert_prob_topk=topk_weights,
+                    num_experts=layer.num_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=override_buffer_size,
+                )
+            )
+
+            # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
+            token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
+                paddle.int64
+            )
+            if topk_ids_hookfunc is not None:
+                topk_ids_hookfunc(topk_ids=topk_idx)
+
+            ffn_out = self.compute_ffn(
+                layer,
+                permute_input,
+                token_nums_per_expert_cumsum,
+                None,  # expert_idx_per_token not needed for w16a16 without bias
+                False,
+                -1,
+                None,  # dequant_scale
+                None,  # max_tokens_per_expert
+            )
+
+            fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                hidden_states_unzipped=ffn_out,
+                zipped_expertwise_rowmap=permute_indices_per_token,
+                expert_routemap_topk=topk_idx_i32,
+                token_prob_unzipped=dst_weights,
+                total_zipped_tokens=x.shape[0],
+                num_experts=layer.num_experts,
+                using_weighted_combine=True,
+            )
+            return fused_moe_out
+
         if layer.topk_method == "noaux_tc":
             gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
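On the buffer sizing used above: override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) reads as a worst-case row count for the permuted buffer, where every token can occupy top_k rows and every expert block may need up to 127 padding rows to reach the 128-row alignment. A quick worked example of that arithmetic (the interpretation, not the formula, is an assumption):

num_tokens, top_k, num_experts, alignment = 512, 8, 64, 128
override_buffer_size = num_tokens * top_k + num_experts * (alignment - 1)
print(override_buffer_size)  # 512 * 8 + 64 * 127 = 12224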
@@ -401,7 +506,6 @@ def apply_tp(
             expert_idx_per_token = None
         else:
             expert_idx_per_token = expert_idx_per_token.cast("int64")
-
         ffn_out = self.compute_ffn(
             layer,
             permute_input,
@@ -423,7 +527,6 @@ def apply_tp(
             norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
             routed_scaling_factor=1.0,
         )
-
         return fused_moe_out


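Taken together, both new branches are opt-in: they run only when fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE is truthy and the layer's moe_quant_type is "w16a16"; every quantized configuration keeps the existing ep_moe_expert_dispatch / ep_moe_expert_combine (or moe_expert_dispatch / moe_expert_reduce) path. A minimal sketch of that gate, with the environment-variable spelling assumed from the usual FD_* convention in fastdeploy.envs:

import os

# Assumption: FD_USE_PHI_MOE_PERMUTE is surfaced like other FD_* switches, via the environment.
os.environ.setdefault("FD_USE_PHI_MOE_PERMUTE", "1")

def use_phi_moe_permute(moe_quant_type: str) -> bool:
    # Same condition the diff checks in apply_ep_prefill and apply_tp.
    return os.environ.get("FD_USE_PHI_MOE_PERMUTE", "0") == "1" and moe_quant_type == "w16a16"

print(use_phi_moe_permute("w16a16"))  # True  -> moe_permute / moe_unpermute path
print(use_phi_moe_permute("w4a8"))    # False -> original dispatch / combine path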