@@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
207207 return ffn_out
208208
209209
def moe_topk_select(
    gating_output: paddle.Tensor,
    n_group: int,
    topk_group: int,
    top_k: int,
    routed_scaling_factor: float,
    e_score_correction_bias: paddle.Tensor,
    renormalize: bool = False,
):
    """
    Select the top-k experts for every token using plain paddle tensor ops.

    Gate logits are turned into probabilities with a sigmoid. The optional
    correction bias only influences WHICH experts are picked; the returned
    weights are always gathered from the unbiased sigmoid probabilities.
    When more than one group is configured and fewer groups than that are
    to be kept, experts belonging to non-selected groups are masked out
    before the final top-k.

    Args:
        gating_output: gate output logits, shape [seq_len, n_experts]
        n_group: number of expert groups (values <= 0 are treated as 1)
        topk_group: number of groups to keep (values <= 0 are treated as 1)
        top_k: number of experts selected per token
        routed_scaling_factor: multiplier applied to the selected weights;
            skipped when falsy (e.g. None or 0)
        e_score_correction_bias: bias added to the probabilities for the
            selection step only; may be None
        renormalize: whether to divide the selected weights by their sum

    Returns:
        topk_weights: weights of the selected experts, shape [seq_len, top_k]
        topk_ids: indices of the selected experts, shape [seq_len, top_k]
    """
    sigmoid_scores = paddle.nn.functional.sigmoid(gating_output)

    # Scores used purely for ranking; bias is applied here and nowhere else.
    if e_score_correction_bias is None:
        selection_scores = sigmoid_scores
    else:
        selection_scores = sigmoid_scores + e_score_correction_bias

    # Non-positive group settings collapse to a single group.
    if not n_group > 0:
        n_group = 1
    if not topk_group > 0:
        topk_group = 1

    # Group-limited routing: keep only experts inside the best `topk_group` groups.
    if n_group > 1 and topk_group < n_group:
        num_tokens, num_experts = selection_scores.shape

        # A group's score is the sum of its two highest expert scores.
        per_group_view = selection_scores.reshape([num_tokens, n_group, -1])
        best_two = per_group_view.topk(2, axis=-1)[0]
        per_group_score = best_two.sum(axis=-1)  # [seq_len, n_group]

        _, kept_groups = paddle.topk(
            per_group_score, k=topk_group, axis=-1, sorted=True
        )  # [seq_len, topk_group]

        # One-hot the kept group ids and sum over the topk_group axis -> [seq_len, n_group]
        kept_one_hot = paddle.nn.functional.one_hot(kept_groups, num_classes=n_group)
        group_keep = paddle.sum(kept_one_hot.cast(per_group_score.dtype), axis=1)

        # Broadcast the per-group flag to every expert of that group -> [seq_len, n_experts]
        expert_keep = group_keep.unsqueeze(-1)
        expert_keep = expert_keep.expand([num_tokens, n_group, num_experts // n_group])
        expert_keep = expert_keep.reshape([num_tokens, -1])

        dropped = ~expert_keep.astype(paddle.bool)
        selection_scores = selection_scores.masked_fill(dropped, float("-inf"))

    topk_ids = paddle.topk(selection_scores, top_k, axis=-1)[1]
    topk_weights = paddle.index_sample(sigmoid_scores, topk_ids)

    # Normalize the combine weights; the clip guards against division by zero.
    if renormalize:
        weight_sum = topk_weights.sum(-1, keepdim=True)
        topk_weights = topk_weights / paddle.clip(weight_sum, min=1e-12)

    # Apply the routed scaling factor only when one is provided (truthy).
    if not routed_scaling_factor:
        return topk_weights, topk_ids

    return topk_weights * routed_scaling_factor, topk_ids
269-
270-
271210class DeepGemmFusedMoeMethod (MoEMethodBase ):
272211 """
273212 DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
@@ -403,22 +342,7 @@ def apply_ep_prefill(
403342 hidden_size = x .shape [1 ]
404343
405344 # 1. Select topk experts and weights
406- if (
407- fastdeploy .envs .FD_USE_PHI_MOE_TOPK
408- and layer .redundant_table_manger is None
409- and layer .topk_method == "noaux_tc"
410- ):
411- topk_weights , topk_idx = moe_topk_select (
412- gate_out ,
413- layer .n_group ,
414- layer .topk_group ,
415- layer .top_k ,
416- layer .routed_scaling_factor ,
417- layer .gate_correction_bias ,
418- getattr (layer , "renormalize" , True ),
419- )
420- else :
421- topk_idx , topk_weights = self .ep_prefill_runner .moe_select (layer , gate_out )
345+ topk_idx , topk_weights = self .ep_prefill_runner .moe_select (layer , gate_out )
422346
423347 if topk_ids_hookfunc is not None :
424348 topk_ids_hookfunc (topk_ids = topk_idx )
@@ -820,28 +744,16 @@ def apply_tp(
820744 gate_out = gate_out .cast ("float32" )
821745
822746 if layer .topk_method == "noaux_tc" :
823-
824- if not fastdeploy .envs .FD_USE_PHI_MOE_TOPK :
825- _ , topk_weights , topk_ids = fastdeploy .model_executor .layers .moe .moe .get_moe_scores (
826- gate_out ,
827- layer .n_group ,
828- layer .topk_group ,
829- layer .top_k ,
830- layer .routed_scaling_factor ,
831- layer .gate_correction_bias ,
832- getattr (layer , "renormalize" , True ),
833- )
834- else :
835- topk_weights , topk_ids = moe_topk_select (
836- gate_out ,
837- layer .n_group ,
838- layer .topk_group ,
839- layer .top_k ,
840- layer .routed_scaling_factor ,
841- layer .gate_correction_bias ,
842- getattr (layer , "renormalize" , True ),
843- )
844-
747+ _ , topk_weights , topk_ids = fastdeploy .model_executor .layers .moe .moe .get_moe_scores (
748+ gate_out ,
749+ layer .n_group ,
750+ layer .topk_group ,
751+ layer .top_k ,
752+ layer .routed_scaling_factor ,
753+ layer .gate_correction_bias ,
754+ getattr (layer , "renormalize" , True ),
755+ topk_reduce_func = getattr (layer , "topk_reduce_func" , None ),
756+ )
845757 else :
846758 topk_ids , topk_weights = fastdeploy .model_executor .ops .gpu .moe_topk_select (
847759 gate_out ,
0 commit comments