Skip to content

Cache ctx.saved_tensors in Mamba3 backward to fix FSDP activation checkpointing#909

Open
anasiri wants to merge 1 commit into
state-spaces:mainfrom
anasiri:fix/checkpoint-saved-tensors-double-access
Open

Cache ctx.saved_tensors in Mamba3 backward to fix FSDP activation checkpointing#909
anasiri wants to merge 1 commit into
state-spaces:mainfrom
anasiri:fix/checkpoint-saved-tensors-double-access

Conversation

@anasiri
Copy link
Copy Markdown

@anasiri anasiri commented Apr 9, 2026

_Mamba3Function.backward() accesses ctx.saved_tensors twice. torch.utils.checkpoint registers unpack hooks that only allow a single access, so the second call raises CheckpointError when FSDP activation checkpointing is enabled:

torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Unpack is being triggered for a tensor that was already unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do so only once. Otherwise please open an issue with details on your use case.

This PR caches ctx.saved_tensors in a local variable so both uses read from the same reference.

…pointing

torch.utils.checkpoint unpack hooks only allow a single unpack per tensor.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant