DCP Checkpoint Load Fails for _extra_state when training in FSDP2 #1860

@heuristicoder

Describe the bug

Loading a DCP checkpoint fails on the `_extra_state` entries when training with FSDP2.

Steps/Code to reproduce bug

Save a DCP checkpoint, then try to load it back while training with FSDP2. The load fails with:

[rank4]: ValueError: Size mismatch between saved torch.Size([2322]) and current: torch.Size([4]) for model.diffusion_transformer.layers.0.feed_forward.ffn._extra_state
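
One possible reading of the sizes in this error, offered as an assumption rather than a confirmed diagnosis: the `_extra_state` entry looks like a 1-D byte buffer holding a serialized metadata blob, so its length depends on what was serialized. A pickled `None` is exactly 4 bytes, which matches the `current` size, while a populated metadata dict pickles to a much larger buffer. The sketch below uses only the standard library. `serialized_size`, `check_sizes`, and the metadata field names are hypothetical stand-ins, not Transformer Engine or DCP code.

```python
import pickle


def serialized_size(extra_state):
    """Length in bytes of the pickled extra state (stand-in for the tensor size)."""
    return len(pickle.dumps(extra_state, protocol=4))


def check_sizes(saved_size, current_size, key):
    """Mimic a loader that requires saved and current buffers to have equal size."""
    if saved_size != current_size:
        raise ValueError(
            f"Size mismatch between saved {saved_size} "
            f"and current: {current_size} for {key}"
        )


# A freshly constructed module with no metadata yet: the state is None.
fresh_size = serialized_size(None)

# A module that has run training steps: hypothetical populated metadata.
trained_state = {
    "scale": [1.0] * 16,
    "amax_history": [[0.0] * 16 for _ in range(8)],
}
trained_size = serialized_size(trained_state)

try:
    check_sizes(trained_size, fresh_size, "ffn._extra_state")
    mismatch = None
except ValueError as err:
    mismatch = str(err)

print(fresh_size, trained_size)
print(mismatch)
```

If this reading is right, the saved buffer (2322 bytes) came from a module with populated metadata, and the module being loaded into had not populated its metadata yet, so a loader that insists on identical shapes cannot match the two.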

Expected behavior

We should be able to load back the model properly to resume training.
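
Until the mismatch is resolved, one workaround to consider (untested against this setup) is to keep the `_extra_state` entries out of the dict handed to the checkpoint loader, so that only regular parameters and buffers are shape-checked. The helper below is a minimal sketch of that filtering step. `split_extra_state` is a hypothetical name, and the placeholder values stand in for real tensors. The cost is that whatever metadata lives in `_extra_state` is not restored from the checkpoint.

```python
def split_extra_state(state_dict, suffix="_extra_state"):
    """Partition a state dict into regular entries and `_extra_state` entries."""
    regular, extra = {}, {}
    for key, value in state_dict.items():
        target = extra if key.endswith(suffix) else regular
        target[key] = value
    return regular, extra


# Placeholder values stand in for tensors; only the keys matter here.
example = {
    "layers.0.feed_forward.ffn.weight": "weight-tensor",
    "layers.0.feed_forward.ffn.bias": "bias-tensor",
    "layers.0.feed_forward.ffn._extra_state": "byte-buffer",
}
regular, extra = split_extra_state(example)
print(sorted(regular))
print(sorted(extra))
```

The `regular` dict would then be the one passed to the load call, with the model loaded non-strictly afterwards so the missing `_extra_state` keys are tolerated.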

Environment overview (please complete the following information)

  • Environment location: Baremetal
  • Method of Transformer Engine install: pip install, v2.3

Environment details

If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:

  • OS version: Ubuntu 22.04
  • PyTorch version: 2.7
  • Python version: 3.12
  • Transformer Engine version: v2.3
  • CUDA version: 12.8
  • CUDNN version: ~9

Device details

  • GPU model: H100s

