[bug] double rmsnorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1 #885

@darxradi3nt

The bug was introduced in commit a38f953 (PR #543), in mamba_ssm/ops/selective_scan_interface.py.

MambaInnerFn.forward:

When b_rms_weight or c_rms_weight is provided, B and C are RMS-normalized in place, and the normalized tensors are stored via ctx.save_for_backward.

MambaInnerFn.backward with checkpoint_lvl==1:

B and C are treated the same way as delta: the backward recomputes them and applies rmsnorm again.
But delta is set to None before saving, so it really is recomputed from scratch.
B and C are saved post-norm (not set to None), so the backward loads the already-normalized values and normalizes them a second time before passing them to selective_scan_cuda.bwd.

    if b_rms_weight is not None:
        # Recompute & RMSNorm B
        B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
        B = rms_norm_forward(
            B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
        )
        B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
    if c_rms_weight is not None:
        # Recompute & RMSNorm C
        C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
        C = rms_norm_forward(
            C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
        )
        C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

This block can probably just be removed?
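
To illustrate why the second normalization changes the values, here is a minimal pure-Python sketch. It is not the repo's code: rms_norm below is a hypothetical stand-in for rms_norm_forward using the standard RMSNorm formula, and the input row and weights are made-up numbers.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Standard RMSNorm on a 1-D list: x / sqrt(mean(x^2) + eps) * weight.

    Hypothetical stand-in for rms_norm_forward; not the repo implementation.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def max_abs_diff(a, b):
    return max(abs(p - q) for p, q in zip(a, b))

# Made-up values standing in for one row of B (dstate = 4).
b_raw = [1.0, -2.0, 3.0, -4.0]
b_rms_weight = [0.5, 1.0, 1.5, 2.0]

b_forward = rms_norm(b_raw, b_rms_weight)       # post-norm value saved by forward
b_backward = rms_norm(b_forward, b_rms_weight)  # value after normalizing a second time

max_diff = max_abs_diff(b_forward, b_backward)

# With all-ones weights the second pass is almost a no-op (only eps differs).
ones = [1.0] * len(b_raw)
once = rms_norm(b_raw, ones)
twice = rms_norm(once, ones)
unit_diff = max_abs_diff(once, twice)

print("non-unit weight, max |forward - backward|:", max_diff)
print("unit weight,     max |once - twice|:     ", unit_diff)
```

With non-unit weights the two results differ noticeably, while with weights close to 1 the second pass barely changes anything, which may be part of why the effect is hard to spot.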

Let me know what you think, @younesbelkada.
Were any issues observed during training? They seem quite subtle.

i currently have 2 suggestions for a fix
