Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/layers/moe/ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import fastdeploy
from fastdeploy import envs
from fastdeploy.config import MoEPhase
from fastdeploy.platforms import current_platform
from fastdeploy.utils import singleton


Expand Down Expand Up @@ -531,6 +532,9 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
if layer.topk_method == "noaux_tc":
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

use_fused = (
layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda()
)
score, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
Expand All @@ -540,6 +544,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
use_fused_cast=use_fused,
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
Expand Down
1 change: 1 addition & 0 deletions tests/model_executor/test_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def fake_get_moe_scores(*_args, **_kwargs):
routed_scaling_factor=1.0,
gate_correction_bias=None,
renormalize=False,
fd_config=SimpleNamespace(scheduler_config=SimpleNamespace(enable_moe_scores_elementwise_fuse=False)),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 测试仅覆盖 enable_moe_scores_elementwise_fuse=False 的路径,use_fused=True 时 grouped_topk 的 fused kernel 路径无测试覆盖。

建议增加一个 enable_moe_scores_elementwise_fuse=True 的测试用例,验证 fused 路径与原始路径的路由结果(topk_idx、topk_weights)在数值上保持一致,以防回归。

)
gate_out = paddle.randn([1, 4], dtype="float32")

Expand Down
Loading