Skip to content

fix: #1432 TE RMS Norm numerical Instability#1681

Open
JiwaniZakir wants to merge 1 commit intoNVIDIA-NeMo:mainfrom
JiwaniZakir:fix/1432-te-rms-norm-numerical-instability
Open

fix: #1432 TE RMS Norm numerical Instability#1681
JiwaniZakir wants to merge 1 commit intoNVIDIA-NeMo:mainfrom
JiwaniZakir:fix/1432-te-rms-norm-numerical-instability

Conversation

@JiwaniZakir
Copy link
Copy Markdown
Contributor

Closes #1432

What does this PR do ?

Fix numerical instability in TE RMSNorm by upcasting non-fp32 inputs to float32 before the kernel call, then casting outputs back to the original dtype.

Changelog

  • nemo_automodel/components/models/common/utils.py: In _make_lazy_te_patcher, _patched_rmsnorm_forward now captures input_dtype and, when the input is not already float32, calls _original_rmsnorm_forward with x.float() and casts the result back to input_dtype before returning.
  • tests/unit_tests/models/common/test_model_common_utils.py: Added TestFloat32RMSNorm test class with four tests covering: Float32RMSNorm weight dtype preservation, output dtype consistency through the fp32 computation path, numerical accuracy comparison between fp32-upcast and native bf16 norm against a float64 reference, and verification that the TE RMSNorm patch passes fp32 tensors to the underlying kernel.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information


This PR was created with AI assistance (Claude). The changes were reviewed by quality gates and a critic model before submission.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented Apr 4, 2026

Hi @JiwaniZakir thanks for making this, do you have a wandb that you can share showing the improved stability? Since this change is at a module level, it will affect all models using it, so i want to make sure we have a good understanding/review of the change. Thank you.

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented Apr 5, 2026

/claude review

Comment on lines +230 to +247
"""The TE RMSNorm patch must upcast bf16 inputs to fp32 before the kernel call."""
captured = {}

def mock_original_forward(self_inner, x):
captured["dtype"] = x.dtype
return x # pass-through

instance = MagicMock()
x_bf16 = torch.randn(2, 4, dtype=torch.bfloat16)

# Replicate the patched forward logic from _make_lazy_te_patcher
input_dtype = x_bf16.dtype
if input_dtype != torch.float32:
result = mock_original_forward(instance, x_bf16.float())
else:
result = mock_original_forward(instance, x_bf16)

assert captured["dtype"] == torch.float32, "TE RMSNorm should receive fp32 inputs after upcast patch"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test re-implements the patch logic inline rather than invoking the actual _patched_rmsnorm_forward. If the real patch logic drifts, this test will still pass.

More concretely, it also doesn't verify the .to(input_dtype) cast-back — the result variable is never checked for dtype. Consider adding:

assert result.dtype == torch.bfloat16, "Result should be cast back to the original input dtype"

@ZhiyuLi-Nvidia
Copy link
Copy Markdown
Contributor

Hi, @JiwaniZakir

Thank you for the PR!

As for the next step, could you help us with

  • DCO action: CI/CD requires commit signature, could you do git commit -s --amend and later force push the branch again?
  • Validate PR tile one, we can resolve it by replacing Fix#1432: with fix:

Let us know what you think.

@akoumpa akoumpa changed the title Fix #1432: TE RMS Norm numerical Instability fix: #1432 TE RMS Norm numerical Instability Apr 7, 2026
@chtruong814 chtruong814 added waiting-for-customer waiting-on-customer Waiting on the original author to respond and removed waiting-for-customer labels Apr 14, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

TE RMS Norm numerical Instability

5 participants