 from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
+    from fastdeploy.model_executor.ops.gpu import (
+        count_tokens_per_expert_func,
+        moe_expert_dispatch,
+        moe_expert_reduce,
+    )

     try:
         from fastdeploy.model_executor.ops.gpu import (
@@ -126,14 +130,15 @@ def apply_ep_prefill(
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
+        dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
             recv_num_tokens_per_expert_list,
             handle,
             event,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
+        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)

         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
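Note on the new `expert_alignment=128` dispatch argument: the permute path pads each expert's segment of the dispatch buffer up to a multiple of the alignment so every expert's rows start on a kernel-friendly boundary. A minimal sketch of that padding arithmetic (illustrative only, not the FastDeploy kernel; `aligned_offsets` is a hypothetical helper):

```python
# Hypothetical helper showing how a 128-row expert alignment pads each
# expert's segment of the dispatch buffer (illustrative, not the real kernel).
def aligned_offsets(tokens_per_expert, alignment=128):
    offsets, total = [], 0
    for n in tokens_per_expert:
        offsets.append(total)
        total += -(-n // alignment) * alignment  # round n up to a multiple of alignment
    return offsets, total

# [5, 130, 0] tokens -> padded segments of [128, 256, 0] rows.
offsets, buffer_rows = aligned_offsets([5, 130, 0])
assert offsets == [0, 128, 384] and buffer_rows == 384
```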
@@ -146,54 +151,91 @@ def apply_ep_prefill(
         # 3. Compute ffn
         if token_all_num > 0:
             logger.debug(f"token_all_num {token_all_num}")
-            (
-                permute_input,
-                permute_indices_per_token,
-                recv_num_tokens_per_expert_list_cumsum,
-                dst_weights,
-                dst_indices,
-                cumsum_idx_gpu,
-                expert_idx_per_token,
-                dequant_scale,
-            ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
-                recv_x,
-                recv_topk_idx,
-                recv_topk_weights,
-                (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
-                recv_num_tokens_per_expert_list,
-                token_all_num,
-                self.moe_quant_type,
-            )
-            if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
-                # only w4a8 and w4afp8 need expert_idx_per_token
-                # Other need not this tensor, so we make it None.
-                expert_idx_per_token = None
+
+            if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+                # --- moe_permute / moe_unpermute path ---
+                recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
+                (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                )
+
+                token_nums_per_expert_cumsum = count_tokens_per_expert_func(
+                    recv_topk_idx, layer.num_local_experts, True
+                )[2].cast(paddle.int64)
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    token_nums_per_expert_cumsum,
+                    None,
+                    False,
+                    -1,
+                    None,
+                    None,
+                )
+
+                tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                    hidden_states_unzipped=ffn_out,
+                    zipped_expertwise_rowmap=permute_indices_per_token,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    token_prob_unzipped=dst_weights,
+                    total_zipped_tokens=recv_x.shape[0],
+                    num_experts=layer.num_local_experts,
+                    using_weighted_combine=True,
+                )
             else:
-                expert_idx_per_token = expert_idx_per_token.cast("int64")
+                # --- original ep_moe_expert_dispatch / combine path ---
+                (
+                    permute_input,
+                    permute_indices_per_token,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    dst_weights,
+                    dst_indices,
+                    cumsum_idx_gpu,
+                    expert_idx_per_token,
+                    dequant_scale,
+                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
+                    recv_x,
+                    recv_topk_idx,
+                    recv_topk_weights,
+                    (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
+                    recv_num_tokens_per_expert_list,
+                    token_all_num,
+                    self.moe_quant_type,
+                )
+                if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
+                    # Only w4a8 and w4afp8 need expert_idx_per_token; other quant types can pass None.
+                    expert_idx_per_token = None
+                else:
+                    expert_idx_per_token = expert_idx_per_token.cast("int64")

-            if hasattr(layer, "up_gate_proj_in_scale"):
-                dequant_scale = None
+                if hasattr(layer, "up_gate_proj_in_scale"):
+                    dequant_scale = None

-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                recv_num_tokens_per_expert_list_cumsum,
-                expert_idx_per_token,
-                False,
-                -1,
-                dequant_scale,
-            )
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    expert_idx_per_token,
+                    False,
+                    -1,
+                    dequant_scale,
+                )

-            # prmt back per rank
-            tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
-                ffn_out,
-                dst_weights,
-                permute_indices_per_token,
-                dst_indices,
-                None,  # down_proj_bias
-                False,  # norm_topk_prob
-                1.0,
-            )
+                # Permute back per rank
+                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
+                    ffn_out,
+                    dst_weights,
+                    permute_indices_per_token,
+                    dst_indices,
+                    None,  # down_proj_bias
+                    False,  # norm_topk_prob
+                    1.0,
+                )
         else:
             tmp_ffn_out = recv_x

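The new branch replaces the custom dispatch/combine ops with the paddle.nn.functional.moe_permute / moe_unpermute pair: permute gathers each token's top-k copies into expert-contiguous rows, the grouped FFN runs once over the expert segments, and unpermute scatters the rows back with a probability-weighted combine. A minimal NumPy sketch of those semantics (argument names and shapes are simplified assumptions, not the Paddle API):

```python
import numpy as np

# Simplified semantics of the permute -> expert FFN -> unpermute round trip.
def permute(x, topk_idx, topk_w, num_experts):
    rows, weights = [], []
    for e in range(num_experts):
        tok, k = np.nonzero(topk_idx == e)  # tokens (and which of their k slots) routed to e
        rows.append(tok)
        weights.append(topk_w[tok, k])
    order = np.concatenate(rows)            # expert-major row order
    return x[order], order, np.concatenate(weights)

def unpermute(y, order, row_weights, num_tokens):
    out = np.zeros((num_tokens, y.shape[1]), y.dtype)
    np.add.at(out, order, y * row_weights[:, None])  # weighted combine per source token
    return out

x = np.random.rand(4, 8)                    # 4 tokens, hidden size 8
topk_idx = np.array([[0, 2], [1, 2], [0, 1], [2, 0]])
topk_w = np.full((4, 2), 0.5)
xp, order, w = permute(x, topk_idx, topk_w, num_experts=3)
y = unpermute(xp * 2.0, order, w, num_tokens=4)  # identity "FFN" scaled by 2
assert np.allclose(y, 2.0 * x)              # each token recombines its k weighted copies
```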
@@ -276,6 +318,69 @@ def apply_tp(
         """
         gate_out = gate(x)
         gate_out = gate_out.cast("float32")
+        if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+            if layer.topk_method == "noaux_tc":
+                gate_out, topk_weights, topk_idx = get_moe_scores(
+                    gate_out,
+                    layer.n_group,
+                    layer.topk_group,
+                    layer.top_k,
+                    layer.routed_scaling_factor,
+                    layer.gate_correction_bias,
+                    getattr(layer, "renormalize", True),
+                )
+            else:
+                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                    gate_out,
+                    layer.gate_correction_bias,
+                    layer.top_k,
+                    True,  # apply_norm_weight
+                    False,
+                )
+            topk_idx_i32 = topk_idx.astype(paddle.int32)
+            override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
+                paddle.nn.functional.moe_permute(
+                    hidden_states=x,
+                    scale=None,
+                    expert_routemap_topk=topk_idx_i32,
+                    expert_prob_topk=topk_weights,
+                    num_experts=layer.num_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=override_buffer_size,
+                )
+            )
+
+            # Index 2 of count_tokens_per_expert_func's outputs is the prefix sum of token_nums_per_expert.
+            token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
+                paddle.int64
+            )
+            if topk_ids_hookfunc is not None:
+                topk_ids_hookfunc(topk_ids=topk_idx)
+
+            ffn_out = self.compute_ffn(
+                layer,
+                permute_input,
+                token_nums_per_expert_cumsum,
+                None,  # expert_idx_per_token is not needed for w16a16 without bias
+                False,
+                -1,
+                None,  # dequant_scale
+                None,  # max_tokens_per_expert
+            )
+
+            fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                hidden_states_unzipped=ffn_out,
+                zipped_expertwise_rowmap=permute_indices_per_token,
+                expert_routemap_topk=topk_idx_i32,
+                token_prob_unzipped=dst_weights,
+                total_zipped_tokens=x.shape[0],
+                num_experts=layer.num_experts,
+                using_weighted_combine=True,
+            )
+            return fused_moe_out
+
         if layer.topk_method == "noaux_tc":
             gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
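In this fast path, moe_topk_select with apply_norm_weight=True picks each token's top-k experts and renormalizes the gate weights over just those k. A hedged NumPy sketch of that selection (assuming a plain softmax gate; the real op also folds in gate_correction_bias before ranking):

```python
import numpy as np

# Hedged sketch of top-k gate selection with weight renormalization.
def topk_select(gate_out, top_k):
    probs = np.exp(gate_out - gate_out.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)                 # softmax over experts
    idx = np.argsort(-probs, axis=-1)[:, :top_k]          # top-k experts per token
    w = np.take_along_axis(probs, idx, axis=-1)
    w /= w.sum(-1, keepdims=True)                         # renormalize over the k picked
    return idx, w

idx, w = topk_select(np.random.randn(4, 64), top_k=8)
assert np.allclose(w.sum(-1), 1.0)                        # weights sum to 1 per token
```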
@@ -287,6 +392,7 @@ def apply_tp(
                 getattr(layer, "renormalize", True),
                 topk_reduce_func=getattr(layer, "topk_reduce_func", None),
             )
+
         (
             permute_input,
             token_nums_per_expert,
@@ -341,7 +447,6 @@ def apply_tp(
             expert_idx_per_token = None
         else:
             expert_idx_per_token = expert_idx_per_token.cast("int64")
-
         ffn_out = self.compute_ffn(
             layer,
             permute_input,
@@ -363,7 +468,6 @@ def apply_tp(
             norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
             routed_scaling_factor=1.0,
         )
-
         return fused_moe_out

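One detail worth calling out from the TP branch above: override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) is a worst-case bound on the padded permute buffer, since the T * top_k routed rows each take one slot and rounding every expert's segment up to a multiple of 128 adds at most 127 pad rows per expert. A quick randomized check of that bound:

```python
import random

# Verify: sum over experts of ceil(count / 128) * 128 never exceeds
# T * top_k + num_experts * (128 - 1), for any routing of T * top_k rows.
T, top_k, num_experts, align = 16, 8, 64, 128
bound = T * top_k + num_experts * (align - 1)
for _ in range(1000):
    counts = [0] * num_experts
    for _ in range(T * top_k):
        counts[random.randrange(num_experts)] += 1
    padded = sum(-(-c // align) * align for c in counts)
    assert padded <= bound
```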