paddleformers/transformers/minimax_m2/modeling.py (merged; 11 additions, 0 deletions)
@@ -171,6 +171,7 @@ def _add_layer_slice_config(prefix):
         # FFN gate_up weights
         if ffn_slice_fn is not None:
             moe_intermediate_size = config.moe_intermediate_size
+            intermediate_size = config.intermediate_size

             # Fused experts
             param_name = f"{prefix}.mlp.experts.up_gate_proj.weight"
@@ -181,6 +182,13 @@ def _add_layer_slice_config(prefix):
                 ffn_slice_fn,
                 {"intermediate_size": moe_intermediate_size},
             )
+            slice_config[f"{prefix}.mlp.grouped_gemm_experts.weight1"] = (
+                ffn_slice_fn,
+                {"intermediate_size": moe_intermediate_size},
+            )
+            # Common experts
+            param_name = f"{prefix}.mlp.up_gate_proj.weight"
+            slice_config[param_name] = (ffn_slice_fn, {"intermediate_size": intermediate_size})

         # Routed experts (per-expert)
         if hasattr(config, "n_routed_experts") and config.n_routed_experts > 0:
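
Each `slice_config` entry maps a parameter name to a `(slice_fn, kwargs)` pair that is applied when a checkpoint is split across tensor-parallel ranks. As a minimal sketch of what an `ffn_slice_fn` for the fused `up_gate_proj` weight could look like, assuming the weight is laid out as `[hidden_size, 2 * intermediate_size]` with the gate and up halves concatenated along the output axis (both the layout and the signature here are assumptions for illustration, not the repo's actual API):

import numpy as np

# Hypothetical column-parallel slice for a fused gate/up projection.
# Each rank must take the matching column block from BOTH halves, not one
# contiguous slice, so the fused layout stays valid after splitting.
def ffn_slice_fn(weight, tensor_parallel_degree, tensor_parallel_rank, intermediate_size):
    gate = weight[:, :intermediate_size]
    up = weight[:, intermediate_size:]
    shard = intermediate_size // tensor_parallel_degree
    lo = tensor_parallel_rank * shard
    hi = lo + shard
    return np.concatenate([gate[:, lo:hi], up[:, lo:hi]], axis=-1)

The kwargs dict supplies the split extent, which is why the new common-experts entry passes `config.intermediate_size` while the expert entries pass `config.moe_intermediate_size`; the `intermediate_size` local added in the first hunk exists to feed that new entry.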
@@ -193,6 +201,7 @@ def _add_layer_slice_config(prefix):
         # Fused MoE weights (grouped_gemm)
         if moe_grouped_gemm and fused_moe_fn is not None:
             slice_config[f"{prefix}.mlp.experts.down_proj.weight"] = (fused_moe_fn, {})
+            slice_config[f"{prefix}.mlp.grouped_gemm_experts.weight2"] = (fused_moe_fn, {})

         # MLA weights
         if use_mla and mla_slice_fn is not None:
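
`down_proj` (and its grouped-GEMM alias `weight2`) is the row-parallel side of the FFN, which is consistent with its slice entry carrying no extra kwargs: the split extent can be read off the weight itself. A hedged sketch of such a function for the plain 2D case, with an assumed signature; the real `fused_moe_fn` may additionally handle expert-stacked grouped-GEMM layouts:

# Hypothetical row-parallel slice for down_proj / weight2 (2D case only).
# Assumption: weight shape is [intermediate_size, hidden_size]; each rank
# keeps a contiguous block of input rows.
def fused_moe_fn(weight, tensor_parallel_degree, tensor_parallel_rank):
    shard = weight.shape[0] // tensor_parallel_degree
    lo = tensor_parallel_rank * shard
    return weight[lo : lo + shard, :]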
@@ -244,6 +253,8 @@ def _add_layer_slice_config(prefix):
     num_nextn_predict_layers = config.num_nextn_predict_layers if config.num_nextn_predict_layers else 0
     for layer_idx in range(num_nextn_predict_layers):
         _add_layer_slice_config(f"model.layers.{num_hidden_layers + layer_idx}")
+    for layer_idx in range(num_nextn_predict_layers):
+        _add_layer_slice_config(f"model.layers.{num_hidden_layers + layer_idx}.transformer_layer")

     return slice_config

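The second loop registers the same next-n-prediction (MTP) layers again under a `.transformer_layer` suffix, so a parameter matches a slice rule whether or not the checkpoint nests the MTP block one level deeper. A small illustration with hypothetical values:

# With num_hidden_layers = 4 and num_nextn_predict_layers = 1, the two loops
# above register rules under both prefixes:
num_hidden_layers, num_nextn_predict_layers = 4, 1
prefixes = [f"model.layers.{num_hidden_layers + i}" for i in range(num_nextn_predict_layers)]
prefixes += [f"model.layers.{num_hidden_layers + i}.transformer_layer" for i in range(num_nextn_predict_layers)]
print(prefixes)  # ['model.layers.4', 'model.layers.4.transformer_layer']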