Skip to content

Commit b336db7

Browse files
authored
fix_moe_learable-score111 (#7903)
Fix routed_scaling_factor_learnable not taking effect in cutlass backend apply_tp
1 parent e56d9ff commit b336db7

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,7 @@ def apply_tp(
446446
gate_out = gate_out.cast("float32")
447447
if fc1_latent_proj is not None:
448448
x = fc1_latent_proj(x)
449-
gate_out, topk_weights, topk_idx = get_moe_scores(
449+
gate_out, _, __ = get_moe_scores(
450450
gate_out,
451451
layer.n_group,
452452
layer.topk_group,
@@ -458,11 +458,6 @@ def apply_tp(
458458
use_fused_cast=use_fused,
459459
)
460460

461-
if layer.routed_scaling_factor_learnable:
462-
safe_topk_indices = paddle.clip(topk_idx, min=0)
463-
gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1)
464-
topk_weights = topk_weights * gathered_scales
465-
466461
(
467462
permute_input,
468463
token_nums_per_expert,
@@ -484,6 +479,12 @@ def apply_tp(
484479
self.moe_quant_type,
485480
topk_only_mode=True,
486481
)
482+
483+
if layer.routed_scaling_factor_learnable:
484+
safe_topk_indices = paddle.clip(topk_idx, min=0)
485+
gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1)
486+
topk_weights = topk_weights * gathered_scales
487+
487488
else:
488489
gate_out = gate_out.cast("float32")
489490
if fc1_latent_proj is not None:

0 commit comments

Comments
 (0)