6 changes: 5 additions & 1 deletion paddleformers/trainer/trainer.py
@@ -1766,7 +1766,11 @@ def optimizer_step(self, args, model, parameters_list=None):
             for buffer in buffers:
                 buffer._clear_grad_storage()
         else:
-            self.optimizer.clear_grad()
+            # Use set_to_zero=False so parameters not activated in the next step
+            # keep grad=None rather than a zero tensor. With set_to_zero=True (the default),
+            # AdamW applies weight_decay to inactive MoE experts (a grad=0 tensor != grad=None),
+            # causing weight divergence from PyTorch/HF, which skips grad=None parameters.
+            self.optimizer.clear_grad(set_to_zero=False)
 
     def _get_meshes_for_loader(self):
         return self.global_mesh.get_mesh_with_dim("pp")[0]
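Why `set_to_zero=False` matters: a `None` grad short-circuits the optimizer update entirely, while a zero grad tensor still triggers decoupled weight decay. A minimal NumPy sketch of the difference (simplified, moment terms omitted; `adamw_step` is a hypothetical stand-in, not the Paddle or PyTorch API):

```python
import numpy as np

def adamw_step(w, grad, lr=1e-3, weight_decay=0.1):
    """Toy decoupled-AdamW update: weight decay, then a moment-free grad step."""
    if grad is None:
        # PyTorch/HF convention: parameters with grad=None are skipped entirely.
        return w
    # Decoupled weight decay fires whenever a grad tensor exists -- even all-zeros.
    w = w * (1 - lr * weight_decay)
    return w - lr * grad

w = np.ones(3)
print(adamw_step(w, None))         # [1. 1. 1.]   grad=None: untouched
print(adamw_step(w, np.zeros(3)))  # [0.9999 ...] grad=0: shrunk by weight decay
```

Run over many steps, that per-step shrinkage compounds on every inactive expert, which is exactly the divergence the comment describes.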
17 changes: 13 additions & 4 deletions paddleformers/transformers/qwen3_moe/modeling.py
@@ -404,6 +404,11 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         )
         # [num_experts, topk, bs*seq]
         tokens_per_expert = expert_mask.reshape([expert_mask.shape[0], -1]).sum(axis=-1)
+        # Fix 1: create a separate graph node for MLP indexing to avoid a PaddlePaddle autograd
+        # gradient-accumulation mismatch. hidden_states is shared by both the routing path
+        # (gate → softmax → topk → normalize) and multiple expert MLP subgraphs. Without this
+        # separation, PaddlePaddle accumulates gradients differently from PyTorch, causing diffs.
+        hidden_states_for_mlp = hidden_states + 0
         # Loop over all available experts in the model and perform the computation on each expert
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
@@ -413,16 +418,20 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             if tokens_per_expert[expert_idx] <= 0.1:
                 if self.training and paddle.is_grad_enabled():
+                    # Fix 2: do NOT pass through expert_layer parameters for the fake path.
+                    # The original code ran expert_layer(x * 0), which produced grad=0 tensors
+                    # for all expert weights. AdamW still applies weight_decay on grad=0
+                    # (unlike grad=None), causing inactive expert weights to diverge from HF
+                    # after each step. Instead, directly add zeros to keep expert grad=None.
                     fake_top_x = paddle.zeros(1, dtype=paddle.int64)
-                    fakse_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
-                    fake_state = expert_layer(fakse_current_state * 0)
+                    fake_state = paddle.zeros([1, hidden_states.shape[-1]], dtype=hidden_states.dtype)
                     final_hidden_states.index_add_(
-                        index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype)
+                        index=fake_top_x, axis=0, value=fake_state
                     )
                 else:
                     continue
             else:
-                current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
+                current_state = hidden_states_for_mlp[idx, None].reshape([-1, hidden_states.shape[-1]])
                 current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
                 final_hidden_states.index_add_(
                     index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
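On Fix 1: `hidden_states + 0` is an identity op, but it gives the expert subgraphs their own accumulation point in the autograd graph. A minimal dygraph sketch of the mechanics (illustrative only; it shows what the extra node changes structurally, not the Paddle-vs-PyTorch mismatch itself):

```python
import paddle

x = paddle.randn([3])
x.stop_gradient = False

# `y = x + 0` inserts a fresh autograd node: gradients from every consumer of `y`
# are summed at `y` first and reach `x` in one accumulation, instead of each
# consumer accumulating onto `x` directly.
y = x + 0
loss = (y * 2.0).sum() + (y * 3.0).sum()
loss.backward()
print(x.grad)  # 2 + 3 = 5 in every slot; the value is identical either way
```

The numeric gradient is the same with or without the indirection; the fix targets where accumulation happens, which is what drifted from PyTorch.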
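On Fix 2: the key difference is whether the expert's parameters ever enter the graph. A toy sketch (hypothetical `paddle.nn.Linear` stand-ins for expert MLPs):

```python
import paddle

old_expert = paddle.nn.Linear(4, 4)  # inactive expert, old fake path
new_expert = paddle.nn.Linear(4, 4)  # inactive expert, new fake path
x = paddle.randn([2, 4])
x.stop_gradient = False

# Old fake path: expert_layer(x * 0) still routes the graph through the expert's
# weights, so backward() materializes a grad=0 tensor on them.
old_expert(x * 0).sum().backward()
print(old_expert.weight.grad is None)  # False: zero tensor, weight decay still applies

# New fake path: add zeros directly; the expert never joins the graph, so its
# weights keep grad=None and AdamW skips them, matching PyTorch/HF.
(x + paddle.zeros([2, 4], dtype=x.dtype)).sum().backward()
print(new_expert.weight.grad is None)  # True
```

Together with the trainer.py change (which stops `clear_grad` from re-materializing zero grads each step), inactive experts now stay out of the optimizer update entirely.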