Skip to content

Commit 9c65655

Browse files
authored
[Cherry-Pick][RL] support moe-topk use topk_reduce_func #7218 (#7256)
* support moe-topk use topk_reduce_func * fix ep error * fix ut * fix ut
1 parent 0181884 commit 9c65655

7 files changed

Lines changed: 66 additions & 112 deletions

File tree

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
509509
expert_in_rank_num_list=expert_in_rank_num_list,
510510
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
511511
redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
512+
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
512513
)
513514
else:
514515
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -534,6 +535,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
534535
layer.routed_scaling_factor,
535536
layer.gate_correction_bias,
536537
getattr(layer, "renormalize", True),
538+
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
537539
)
538540
else:
539541
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def apply_tp(
285285
layer.routed_scaling_factor,
286286
layer.gate_correction_bias,
287287
getattr(layer, "renormalize", True),
288+
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
288289
)
289290
(
290291
permute_input,

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 11 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
207207
return ffn_out
208208

209209

210-
def moe_topk_select(
211-
gating_output: paddle.Tensor,
212-
n_group: int,
213-
topk_group: int,
214-
top_k: int,
215-
routed_scaling_factor: float,
216-
e_score_correction_bias: paddle.Tensor,
217-
renormalize: bool = False,
218-
):
219-
"""
220-
Topk selection using paddle PHI topk API.
221-
222-
Args:
223-
gating_output: gate output logits, shape [seq_len, n_experts]
224-
n_group: number of expert groups
225-
topk_group: number of top-k groups to select
226-
top_k: number of top experts per token
227-
routed_scaling_factor: scaling factor for routed experts
228-
e_score_correction_bias: bias for expert selection
229-
renormalize: whether to renormalize topk probabilities
230-
231-
Returns:
232-
topk_weights: normalized topk probabilities, shape [seq_len, top_k]
233-
topk_ids: topk expert indices, shape [seq_len, top_k]
234-
"""
235-
# compute gate probs via sigmoid
236-
gate_probs = paddle.nn.functional.sigmoid(gating_output)
237-
# probs_for_choice includes correction bias for topk selection
238-
probs_for_choice = gate_probs + e_score_correction_bias if e_score_correction_bias is not None else gate_probs
239-
# group-based topk selection
240-
n_group = n_group if n_group > 0 else 1
241-
topk_group = topk_group if topk_group > 0 else 1
242-
if n_group > 1 and topk_group < n_group:
243-
seq_length, n_experts = probs_for_choice.shape
244-
group_scores = (
245-
probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
246-
) # [seq_len, n_group]
247-
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group]
248-
group_mask = paddle.sum(
249-
paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
250-
axis=1, # Sum over topk_group dimension -> [seq_len, n_group]
251-
)
252-
score_mask = (
253-
group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1])
254-
) # [seq_len, n_experts]
255-
probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
256-
257-
_, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
258-
topk_weights = paddle.index_sample(gate_probs, topk_ids)
259-
260-
# normalize combine weights
261-
if renormalize:
262-
topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12)
263-
264-
# apply routed scaling factor
265-
if routed_scaling_factor:
266-
topk_weights = topk_weights * routed_scaling_factor
267-
268-
return topk_weights, topk_ids
269-
270-
271210
class DeepGemmFusedMoeMethod(MoEMethodBase):
272211
"""
273212
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
@@ -403,22 +342,7 @@ def apply_ep_prefill(
403342
hidden_size = x.shape[1]
404343

405344
# 1. Select topk experts and weights
406-
if (
407-
fastdeploy.envs.FD_USE_PHI_MOE_TOPK
408-
and layer.redundant_table_manger is None
409-
and layer.topk_method == "noaux_tc"
410-
):
411-
topk_weights, topk_idx = moe_topk_select(
412-
gate_out,
413-
layer.n_group,
414-
layer.topk_group,
415-
layer.top_k,
416-
layer.routed_scaling_factor,
417-
layer.gate_correction_bias,
418-
getattr(layer, "renormalize", True),
419-
)
420-
else:
421-
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
345+
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
422346

423347
if topk_ids_hookfunc is not None:
424348
topk_ids_hookfunc(topk_ids=topk_idx)
@@ -820,28 +744,16 @@ def apply_tp(
820744
gate_out = gate_out.cast("float32")
821745

822746
if layer.topk_method == "noaux_tc":
823-
824-
if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK:
825-
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
826-
gate_out,
827-
layer.n_group,
828-
layer.topk_group,
829-
layer.top_k,
830-
layer.routed_scaling_factor,
831-
layer.gate_correction_bias,
832-
getattr(layer, "renormalize", True),
833-
)
834-
else:
835-
topk_weights, topk_ids = moe_topk_select(
836-
gate_out,
837-
layer.n_group,
838-
layer.topk_group,
839-
layer.top_k,
840-
layer.routed_scaling_factor,
841-
layer.gate_correction_bias,
842-
getattr(layer, "renormalize", True),
843-
)
844-
747+
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
748+
gate_out,
749+
layer.n_group,
750+
layer.topk_group,
751+
layer.top_k,
752+
layer.routed_scaling_factor,
753+
layer.gate_correction_bias,
754+
getattr(layer, "renormalize", True),
755+
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
756+
)
845757
else:
846758
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
847759
gate_out,

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,22 @@ def get_moe_scores(
9090
expert_in_rank_num_list: paddle.Tensor = None,
9191
tokens_per_expert_stats_list: paddle.Tensor = None,
9292
redundant_ep_rank_num_plus_one: int = 1,
93+
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
9394
) -> paddle.Tensor:
9495
"""
9596
compute moe scores using e_score_correction_bias.
9697
"""
9798
scores = paddle.nn.functional.sigmoid(gating_output)
9899
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
99100
scores_with_bias = scores + e_score_correction_bias
101+
102+
if envs.FD_USE_PHI_MOE_TOPK:
103+
# calculate renormalize and routed_scaling_factor value outside the noaux_tc
104+
original_renormalize = renormalize
105+
original_routed_scaling_factor = routed_scaling_factor
106+
renormalize = False
107+
routed_scaling_factor = 1.0
108+
100109
if expert_id_to_ep_rank_array is None:
101110
scores, topk_values, topk_idx = noaux_tc(
102111
scores,
@@ -123,6 +132,16 @@ def get_moe_scores(
123132
routed_scaling_factor,
124133
redundant_ep_rank_num_plus_one,
125134
)
135+
if envs.FD_USE_PHI_MOE_TOPK:
136+
if original_renormalize:
137+
if topk_reduce_func is not None:
138+
topk_values = topk_values / topk_reduce_func(topk_values)
139+
else:
140+
# 使用默认的 sum + epsilon
141+
topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20)
142+
143+
if original_routed_scaling_factor != 1.0:
144+
topk_values *= original_routed_scaling_factor
126145
return scores, topk_values, topk_idx
127146

128147

@@ -152,6 +171,8 @@ def __init__(
152171
with_bias: bool = False,
153172
activation="swiglu",
154173
model_format: Optional[str] = None,
174+
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True)
175+
+ 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel
155176
):
156177
"""
157178
Initialize the Moe layer with given parameters.
@@ -197,6 +218,7 @@ def __init__(
197218
self.moe_tag = moe_tag
198219
self.with_bias = with_bias
199220
self.activation = activation
221+
self.topk_reduce_func = topk_reduce_func
200222

201223
if self.ep_size > 1:
202224
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def __init__(
182182
layer_idx=layer_id,
183183
gate_correction_bias=self.gate.e_score_correction_bias,
184184
weight_key_map=weight_key_map,
185+
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
185186
)
186187

187188
if self.n_shared_experts > 0:

tests/layers/test_fused_moe_cutlass_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,9 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1):
388388
np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0))
389389

390390
def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch):
391-
def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize):
391+
def fake_get_moe_scores(
392+
gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None
393+
):
392394
return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]])
393395

394396
def fake_dispatch(*args, **kwargs):

tests/operators/test_noaux_tc_redundant.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import unittest
16+
from unittest import mock
217

318
import paddle
419

5-
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
6-
moe_topk_select,
7-
)
820
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
921

1022

@@ -135,15 +147,17 @@ def test_group_topk_using_phi_topk(self):
135147
e_score_correction_bias=e_score_correction_bias,
136148
)
137149

138-
topk_values, topk_idx = moe_topk_select(
139-
gating_output=gating_output,
140-
n_group=n_group,
141-
topk_group=topk_group,
142-
top_k=top_k,
143-
routed_scaling_factor=routed_scaling_factor,
144-
e_score_correction_bias=e_score_correction_bias,
145-
renormalize=renormalize,
146-
)
150+
with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}):
151+
new_score, topk_values, topk_idx = get_moe_scores(
152+
gating_output=gating_output,
153+
n_group=n_group,
154+
topk_group=topk_group,
155+
top_k=top_k,
156+
routed_scaling_factor=routed_scaling_factor,
157+
e_score_correction_bias=e_score_correction_bias,
158+
renormalize=renormalize,
159+
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
160+
)
147161

148162
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
149163
equal_topk_ids = paddle.allclose(

0 commit comments

Comments
 (0)