
[Qwen3MoE] Fix gradient alignment between PaddlePaddle and PyTorch/HF #4358

Open

a31413510 wants to merge 1 commit into develop from fix/qwen3-moe-gradient-alignment

Conversation

@a31413510
Collaborator

Problem

When training Qwen3MoE with PaddleFormers, gradients and post-optimizer weights diverge from HuggingFace (PyTorch) starting from step 2. Three root causes were identified, all related to MoE expert activation behavior.

Root Causes & Fixes

Fix 1: hidden_states shared between routing and MLP paths (modeling.py)

`hidden_states` participates simultaneously in the routing computation (gate → softmax → topk → normalize) and in every expert MLP subgraph inside the expert loop. PaddlePaddle autograd accumulates gradients differently from PyTorch for such shared nodes, causing gradient differences of up to 65536 ULP.

Fix: Add `hidden_states_for_mlp = hidden_states + 0` before the expert loop to create an independent graph node for the MLP path.
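A minimal sketch of where the new line sits in the sparse MoE forward, assuming the HF-style Qwen3MoE layout (`self.gate`, `self.top_k`, the expert loop are assumptions, not the exact PaddleFormers code); only the `+ 0` line is from the PR:

```python
import paddle
import paddle.nn.functional as F

def sparse_moe_forward(self, hidden_states):
    # hidden_states: [num_tokens, hidden_size]
    # Routing path: gate -> softmax -> topk -> normalize.
    router_logits = self.gate(hidden_states)
    routing_weights = F.softmax(router_logits, axis=-1)
    routing_weights, selected_experts = paddle.topk(
        routing_weights, self.top_k, axis=-1)
    routing_weights = routing_weights / routing_weights.sum(
        axis=-1, keepdim=True)

    # Fix 1: "+ 0" gives the expert MLPs their own graph node, so the
    # routing path and the expert subgraphs no longer accumulate
    # backward gradients into the same shared input.
    hidden_states_for_mlp = hidden_states + 0
    ...  # the expert loop reads hidden_states_for_mlp from here on
```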

Fix 2: fake_path produces grad=0 tensors for inactive experts (modeling.py)

When no tokens are routed to an expert, the original code ran `expert_layer(fake_current_state * 0)` to maintain graph connectivity. Although the input is zero, the backward pass produces grad=0 tensors (not None) for all of that expert's weights. AdamW still applies weight decay when grad is 0 (`w *= 1 - lr*wd`) but skips parameters whose grad is None, so inactive expert weights drift from HF by ~1 ULP per step. Over multiple steps this accumulates into significant GEMM output errors (~1280 ULP).
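The asymmetry is easy to reproduce in isolation; a minimal sketch (parameter shape and hyperparameters are illustrative):

```python
import paddle

w = paddle.create_parameter([4], dtype="float32")
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=[w], weight_decay=0.01)

opt.step()  # w.grad is None -> AdamW skips the parameter, w unchanged

# Multiplying by zero still routes the graph through w, just like the
# fake expert call: backward leaves grad=0, not None.
(w * 0.0).sum().backward()
print(w.grad)  # zero tensor, not None
opt.step()     # decoupled weight decay fires: w *= 1 - lr * wd
```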

Fix: Replace the fake expert call with `paddle.zeros([1, hidden_size])` added directly to `final_hidden_states`, bypassing all expert parameters so their gradients remain None.
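A before/after sketch of the fake_path branch; variable names follow the PR description, and the surrounding loop is assumed:

```python
# Inside the expert loop, when this expert received no tokens:

# Before -- the zero input still flows through the expert's linear
# layers, so backward writes grad=0 into every weight of this expert:
#   final_hidden_states += expert_layer(fake_current_state * 0)

# After -- a constant zero keeps the output connected without touching
# any expert parameter, so each expert weight's .grad stays None:
final_hidden_states += paddle.zeros(
    [1, hidden_size], dtype=final_hidden_states.dtype)
```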

Fix 3: clear_grad(set_to_zero=True) revives dead gradients (trainer.py)

The default `clear_grad()` (i.e., `set_to_zero=True`) converts every grad=None into a zero tensor after the optimizer step. An expert active in step N but inactive in step N+1 then carries a zero-grad tensor into the next AdamW update, triggering weight decay. PyTorch's `zero_grad(set_to_none=True)` (the default since PyTorch 2.0) keeps grad=None, so HF skips those parameters entirely.

Fix: Change `clear_grad()` to `clear_grad(set_to_zero=False)`, which releases gradient memory and keeps grad=None, matching PyTorch's behavior.
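A sketch of where the change lands in the training loop; the loop itself is assumed, only the `clear_grad` call reflects the PR:

```python
for step, batch in enumerate(dataloader):
    loss = model(**batch)
    loss.backward()
    optimizer.step()
    # Before: optimizer.clear_grad()  # set_to_zero=True zeroes grads
    #                                 # in place, reviving None as 0
    optimizer.clear_grad(set_to_zero=False)  # release grads -> None,
                                             # like set_to_none=True
```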

Verification

Tested on Qwen3MoE with num_experts=120, num_hidden_layers=28. Across all 5 training steps, all 1182 post-optimizer weights are bit-exact (0 ULP) between PaddleFormers and HuggingFace/ms-swift.
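For context, ULP distances like those quoted above can be measured on the raw float32 bit patterns; a hypothetical helper along these lines (not part of the PR):

```python
import numpy as np

def ulp_diff(a, b):
    """Distance in representable float32 steps; 0 means bit-exact."""
    ai = np.asarray(a, dtype=np.float32).view(np.int32).astype(np.int64)
    bi = np.asarray(b, dtype=np.float32).view(np.int32).astype(np.int64)
    # Remap sign-magnitude float bit patterns onto a monotonic integer
    # line so subtraction counts representable values between a and b.
    ai = np.where(ai < 0, np.int64(-2**31) - ai, ai)
    bi = np.where(bi < 0, np.int64(-2**31) - bi, bi)
    return np.abs(ai - bi)

assert ulp_diff([1.0], [1.0]).max() == 0  # identical bits -> 0 ULP
```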

Three fixes to achieve bit-exact gradient and weight alignment with HF
for Qwen3MoE SparseMoeBlock during training:

Fix 1 (modeling.py): Add `hidden_states_for_mlp = hidden_states + 0`
before the expert loop. hidden_states is shared between the routing
path (gate → softmax → topk → normalize) and multiple expert MLP
subgraphs. PaddlePaddle autograd accumulates gradients differently
from PyTorch for such shared nodes, causing up to 65536 ULP diff.
The +0 creates an independent graph node for the MLP path.

Fix 2 (modeling.py): Replace `expert_layer(fake_current_state * 0)`
with `paddle.zeros([1, hidden_size])` in the fake_path branch.
The original code passed a zero tensor through the expert's linear
layers, producing grad=0 tensors (not None) for all expert weights.
AdamW applies weight_decay on grad=0 (w *= 1 - lr*wd) but skips
grad=None, so inactive experts diverged from HF by ~1 ULP per step.

Fix 3 (trainer.py): Change `clear_grad()` to `clear_grad(set_to_zero=False)`.
The default set_to_zero=True converts all grad=None to zero tensors
after the optimizer step. Experts active in step N but inactive in
step N+1 then carry a zero-grad tensor into the next AdamW step,
triggering weight_decay. PyTorch zero_grad(set_to_none=True) keeps
grad=None, so HF skips those parameters entirely.

Verified: all 5 training steps, 1182 post-optimizer weights are
bit-exact (0 ULP) between PaddleFormers and HuggingFace/ms-swift.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@paddle-bot

paddle-bot Bot commented Apr 25, 2026

Thanks for your contribution!
