Skip to content

Commit cbf56ec

Browse files
committed
update muon slice func
1 parent f7a7f79 commit cbf56ec

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

paddleformers/transformers/minimax_m2/modeling.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,7 @@ def _add_layer_slice_config(prefix):
171171
# FFN gate_up weights
172172
if ffn_slice_fn is not None:
173173
moe_intermediate_size = config.moe_intermediate_size
174+
intermediate_size = config.intermediate_size
174175

175176
# Fused experts
176177
param_name = f"{prefix}.mlp.experts.up_gate_proj.weight"
@@ -181,6 +182,13 @@ def _add_layer_slice_config(prefix):
181182
ffn_slice_fn,
182183
{"intermediate_size": moe_intermediate_size},
183184
)
185+
slice_config[f"{prefix}.mlp.grouped_gemm_experts.weight1"] = (
186+
ffn_slice_fn,
187+
{"intermediate_size": moe_intermediate_size},
188+
)
189+
# Common experts
190+
param_name = f"{prefix}.mlp.up_gate_proj.weight"
191+
slice_config[param_name] = (ffn_slice_fn, {"intermediate_size": intermediate_size})
184192

185193
# Routed experts (per-expert)
186194
if hasattr(config, "n_routed_experts") and config.n_routed_experts > 0:
@@ -193,6 +201,7 @@ def _add_layer_slice_config(prefix):
193201
# Fused MoE weights (grouped_gemm)
194202
if moe_grouped_gemm and fused_moe_fn is not None:
195203
slice_config[f"{prefix}.mlp.experts.down_proj.weight"] = (fused_moe_fn, {})
204+
slice_config[f"{prefix}.mlp.grouped_gemm_experts.weight2"] = (fused_moe_fn, {})
196205

197206
# MLA weights
198207
if use_mla and mla_slice_fn is not None:
@@ -244,6 +253,8 @@ def _add_layer_slice_config(prefix):
244253
num_nextn_predict_layers = config.num_nextn_predict_layers if config.num_nextn_predict_layers else 0
245254
for layer_idx in range(num_nextn_predict_layers):
246255
_add_layer_slice_config(f"model.layers.{num_hidden_layers + layer_idx}")
256+
for layer_idx in range(num_nextn_predict_layers):
257+
_add_layer_slice_config(f"model.layers.{num_hidden_layers + layer_idx}.transformer_layer")
247258

248259
return slice_config
249260

0 commit comments

Comments
 (0)