Skip to content

fix: remove double RMSNorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1#939

Open
Chessing234 wants to merge 1 commit into
state-spaces:mainfrom
Chessing234:fix/double-rmsnorm-b-c-backward
Open

fix: remove double RMSNorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1#939
Chessing234 wants to merge 1 commit into
state-spaces:mainfrom
Chessing234:fix/double-rmsnorm-b-c-backward

Conversation

@Chessing234
Copy link
Copy Markdown
Contributor

Bug

MambaInnerFn.backward applies RMSNorm to B and C a second time when checkpoint_lvl == 1, silently corrupting gradients.

Forward pass (selective_scan_interface.py, around line 250):

if b_rms_weight is not None:
    B = rms_norm_forward(B, b_rms_weight, ...)   # B is now post-norm
if c_rms_weight is not None:
    C = rms_norm_forward(C, c_rms_weight, ...)   # C is now post-norm

Then ctx.save_for_backward(..., B, C, ...) saves post-norm B and C.

Only conv1d_out and delta are set to None before saving:

if checkpoint_lvl >= 1:
    conv1d_out, delta = None, None   # will be recomputed in backward

B and C are not set to None — they are saved as already-normalized tensors.

Backward pass (checkpoint_lvl == 1):

  • delta is correctly recomputed from scratch (was None), so re-applying dt_rms is correct.
  • B and C come from ctx.saved_tensors already normalized. The removed blocks applied RMSNorm a second time, making the effective normalization rms_norm(rms_norm(B)) instead of rms_norm(B).

This causes incorrect gradients for any model using b_rms_weight or c_rms_weight with activation checkpointing (checkpoint_lvl=1), such as Mamba2 trained with --checkpoint-activations.

Root cause

The b_rms_weight / c_rms_weight normalization blocks were added inside the checkpoint_lvl == 1 recompute block by analogy with dt_rms_weight, but unlike delta, B and C are saved post-norm and never recomputed — so they must not be re-normalized.

Fix

Remove the 14 lines that re-apply rms_norm_forward to B and C inside the checkpoint_lvl == 1 block. delta re-normalization is unaffected and remains correct.

Fixes #885.

…int_lvl=1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[bug] double rmsnorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1

1 participant