fastdeploy/model_executor/layers/moe
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,7 @@ def apply_tp(
446446 gate_out = gate_out .cast ("float32" )
447447 if fc1_latent_proj is not None :
448448 x = fc1_latent_proj (x )
449- gate_out , topk_weights , topk_idx = get_moe_scores (
449+ gate_out , _ , __ = get_moe_scores (
450450 gate_out ,
451451 layer .n_group ,
452452 layer .topk_group ,
@@ -458,11 +458,6 @@ def apply_tp(
458458 use_fused_cast = use_fused ,
459459 )
460460
461- if layer .routed_scaling_factor_learnable :
462- safe_topk_indices = paddle .clip (topk_idx , min = 0 )
463- gathered_scales = F .embedding (safe_topk_indices , layer .per_expert_scale .unsqueeze (1 )).squeeze (- 1 )
464- topk_weights = topk_weights * gathered_scales
465-
466461 (
467462 permute_input ,
468463 token_nums_per_expert ,
@@ -484,6 +479,12 @@ def apply_tp(
484479 self .moe_quant_type ,
485480 topk_only_mode = True ,
486481 )
482+
483+ if layer .routed_scaling_factor_learnable :
484+ safe_topk_indices = paddle .clip (topk_idx , min = 0 )
485+ gathered_scales = F .embedding (safe_topk_indices , layer .per_expert_scale .unsqueeze (1 )).squeeze (- 1 )
486+ topk_weights = topk_weights * gathered_scales
487+
487488 else :
488489 gate_out = gate_out .cast ("float32" )
489490 if fc1_latent_proj is not None :
You can’t perform that action at this time.
0 commit comments