
fix masking error when using mlp_bias=True causing NaN during gradien…#3699

Open
snehalv2002 wants to merge 1 commit into main from mlp-bias-bug
Conversation

Collaborator

@snehalv2002 snehalv2002 commented Apr 19, 2026

Description

Bug fix for b/497864549: [XL ML] NaN training loss for GPT-OSS SFT.

GPT-OSS currently produces a NaN training loss after one step when using expert_parallelism > 1. The bug was caused by a lack of masking of the MLP bias inside the expert computation. Since no other models use the MLP bias, they were unaffected by this issue.

Issue

Currently, when expert_parallelism > 1 we introduce a buffer to store the output of ragged_all_to_all, which may contain padding along the token axis. If the padding values are not masked after the MLP bias is added, JAX includes them in the gradient computation. jnp.where lets us disconnect the padding values from the backward graph during the bias gradient calculation.
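The gradient behavior can be sketched with a minimal, self-contained repro (all names and shapes here are illustrative, not taken from moe.py; padding garbage is simulated as NaN):

```python
import jax
import jax.numpy as jnp

# Hypothetical buffer: 2 valid tokens followed by 2 padding rows,
# with the padding garbage simulated as NaN.
buf = jnp.array([[1.0, 2.0],
                 [3.0, 4.0],
                 [jnp.nan, jnp.nan],
                 [jnp.nan, jnp.nan]])
mask = jnp.array([True, True, False, False])
bias = jnp.zeros(2)

def loss_unmasked(bias):
    # Padding rows flow into the loss, so the bias gradient picks up NaN.
    return jnp.sum((buf + bias) ** 2)

def loss_masked(bias):
    # jnp.where zeros the padding rows on the forward pass AND zeros
    # their cotangents on the backward pass, so NaNs never reach the
    # bias gradient.
    out = jnp.where(mask[:, None], buf + bias, 0)
    return jnp.sum(out ** 2)

print(jax.grad(loss_unmasked)(bias))  # [nan nan]
print(jax.grad(loss_masked)(bias))    # [ 8. 12.]
```

Note the masking must happen in the same expression that feeds the loss; masking only the forward output while differentiating an unmasked intermediate would still leak NaN into the backward pass.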

Tests

Ran pre_train.train on GPT_OSS with expert_parallelism > 1: https://paste.googleplex.com/6279941382602752

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov Bot commented Apr 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


Collaborator

@RissyRan RissyRan left a comment


Thanks! Could you have a logits correctness check with EP>1 using GPT-OSS?

Comment thread src/maxtext/layers/moe.py

    layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
    if self.config.mlp_bias:
        layer_w0 = layer_w0 + w0_bias
        layer_w0 = jnp.where(mask[:, None], layer_w0, 0)
Collaborator


Is default mask defined without EP sharding?
