The bug was introduced in commit a38f953 from PR #543, in `mamba_ssm/ops/selective_scan_interface.py`.
`MambaInnerFn.forward`: when `b_rms_weight` or `c_rms_weight` is provided, B and C are RMS-normalized in place, and the normalized tensors are stored via `ctx.save_for_backward`.
`MambaInnerFn.backward` with `checkpoint_lvl == 1`: B and C are treated the same way as `delta`, i.e. they are recomputed and RMSNorm is applied (again).
But unlike `delta`, which is correctly set to `None` before saving so that it really is recomputed from scratch, B and C are saved post-norm (not set to `None`).
The backward pass therefore loads the already-normalized values and normalizes them a second time before passing them to `selective_scan_cuda.bwd`.
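One way to see why this matters: RMSNorm is not idempotent once its weight differs from 1, so normalizing the saved, already-normalized B a second time changes the values handed to the backward kernel. The sketch below is a plain-Python RMSNorm, assumed to follow the same math as `rms_norm_forward` (x / sqrt(mean(x^2) + eps) * weight); the input row and the weight values are made up for illustration.

```python
import math


def rms_norm(row, weight, eps=1e-5):
    # Plain RMSNorm over one row: x / sqrt(mean(x^2) + eps) * weight
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms * w for v, w in zip(row, weight)]


b_row = [0.5, -2.0, 1.5, 3.0]       # made-up row of B (one position, dstate = 4)
learned_w = [1.3, 0.7, 2.0, 0.9]    # made-up "learned" b_rms_weight
ones_w = [1.0] * 4                  # all-ones weight

once = rms_norm(b_row, learned_w)     # what the forward scan uses
twice = rms_norm(once, learned_w)     # what the checkpoint_lvl == 1 backward recomputes
max_err = max(abs(a - b) for a, b in zip(once, twice))
print(f"learned weight: max |once - twice| = {max_err:.4f}")

once1 = rms_norm(b_row, ones_w)
twice1 = rms_norm(once1, ones_w)
max_err1 = max(abs(a - b) for a, b in zip(once1, twice1))
print(f"all-ones weight: max |once - twice| = {max_err1:.2e}")
```

With an all-ones weight the second pass only rescales by a factor very close to 1 (it differs only through `eps`), so the double normalization is nearly a no-op; the discrepancy grows as the weight moves away from 1. If these weights start at ones (an assumption here), that may help explain why any training issues would be subtle.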
```python
if b_rms_weight is not None:
    # Recompute & RMSNorm B
    B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    B = rms_norm_forward(
        B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
    )
    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
    # Recompute & RMSNorm C
    C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    C = rms_norm_forward(
        C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
    )
    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
```
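A toy model of the save/recompute pattern makes the mismatch explicit. It also shows that either dropping this recompute or saving pre-norm B/C (mirroring how `delta` is handled) would let the backward see the same B as the forward. Everything here is hypothetical: `ToyCtx`, `toy_forward` and `toy_backward_b` are illustrative stand-ins, not the real `MambaInnerFn` code, and the two options are illustrations rather than a proposed patch. The toy ignores gradients entirely and only checks which B value reaches the backward kernel.

```python
import math


def rms_norm(row, weight, eps=1e-5):
    # Plain RMSNorm over one row: x / sqrt(mean(x^2) + eps) * weight
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms * w for v, w in zip(row, weight)]


class ToyCtx:
    """Hypothetical stand-in for the autograd ctx: it only holds saved values."""


def toy_forward(ctx, b_raw, weight, save_pre_norm):
    b = rms_norm(b_raw, weight)                   # forward always normalizes B
    ctx.weight = weight
    ctx.saved_b = b_raw if save_pre_norm else b   # what gets saved for backward
    return b                                      # B seen by the forward scan


def toy_backward_b(ctx, renorm):
    # Returns the B that would be passed to the backward kernel
    b = ctx.saved_b
    if renorm:                                    # models the recompute block above
        b = rms_norm(b, ctx.weight)
    return b


def same(xs, ys):
    return all(math.isclose(x, y, rel_tol=1e-9, abs_tol=1e-12) for x, y in zip(xs, ys))


b_raw = [0.5, -2.0, 1.5, 3.0]     # made-up row of B
weight = [1.3, 0.7, 2.0, 0.9]     # made-up learned b_rms_weight

variants = {
    "current (save post-norm, re-normalize)": (False, True),
    "drop the recompute (save post-norm)": (False, False),
    "save pre-norm, like delta": (True, True),
}
results = {}
for name, (save_pre_norm, renorm) in variants.items():
    ctx = ToyCtx()
    b_fwd = toy_forward(ctx, b_raw, weight, save_pre_norm)
    results[name] = same(b_fwd, toy_backward_b(ctx, renorm))
    print(f"{name}: backward sees forward's B -> {results[name]}")
```

In the real function the choice between the two options may also depend on how gradients flow back through the RMSNorm to B, C and their weights, which this toy does not model.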
This block can probably just be removed?
Let me know what you think, @younesbelkada.
Were any issues observed during training? They seem quite subtle.
I currently have two suggestions for a fix.
The snippet quoted above corresponds to mamba/mamba_ssm/ops/selective_scan_interface.py, lines 303 to 317 in 126bbf2.