_Note that `torch.save`-converted Megatron-FSDP distributed checkpoints (DCP) cannot be loaded directly into `MegatronFSDP` module classes, because Megatron-FSDP expects a deterministic unevenly sharded checkpoint when loading using DCP. To load a non-distributed checkpoint for training with Megatron-FSDP, simply load the checkpoint into the unsharded model before calling `fully_shard`!_
+ #### Megatron-FSDP Checkpoint State Caveats
+
+ _Note that distributed checkpoints (DCP) converted to `torch.save` format cannot be loaded directly into `MegatronFSDP` module classes. Megatron-FSDP expects an unevenly sharded DCP checkpoint whose metadata defines the distributed read and write sharding strategy for DCP load and save, respectively, and that metadata is not available in `torch.save` checkpoints. To load a non-distributed checkpoint for training with Megatron-FSDP, load the checkpoint into the unsharded model before calling `fully_shard`, as an alternative to loading a DCP checkpoint after `fully_shard`._
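As a rough illustration of the "load before `fully_shard`" path: a checkpoint saved from a wrapped model may carry a `module.` prefix on its keys, and that prefix has to be removed before the unsharded model will accept the state dict. The sketch below is plain Python and is not the repository's `load_torch_checkpoint` helper; the function name `strip_key_prefix` and the sample keys are hypothetical, and integers stand in for tensors.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_key_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    Hypothetical helper for illustration only; keys without the prefix are kept as-is.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in stripped:
            # Two source keys collapsing to one name would silently drop a weight.
            raise KeyError(f"duplicate key after stripping prefix: {new_key!r}")
        stripped[new_key] = value
    return stripped


# Stand-in for a state dict read from a torch.save checkpoint (values would be tensors).
wrapped = {"module.patch_embed.weight": 1, "module.head.bias": 2, "step": 3}
unwrapped = strip_key_prefix(wrapped)
print(list(unwrapped))  # ['patch_embed.weight', 'head.bias', 'step']
```

In a real script, the stripped dictionary would be passed to the unsharded model's `load_state_dict` and only then would the model be wrapped with `fully_shard`.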
```python
+ from checkpoint import load_torch_checkpoint
+
# Initialize model.
model = build_vit_model(cfg, device_mesh)

- # Load model checkpoint. Remove the "module." prefix from the keys from Megatron-FSDP,
- # which is the main discrepancy between Megatron-FSDP and normal checkpoints.
- # Must load with weights_only=False if you have an optimizer state in your checkpoint.
TODO(@cspades): For converting DCP directly to HuggingFace SafeTensors checkpoints, you can look into: https://pytorch.org/blog/huggingface-safetensors-support-in-pytorch-distributed-checkpointing/
+
+ ### Inference
+
+ [infer.py](./infer.py) is an example inference script that loads a non-distributed `torch.save` checkpoint into an unsharded ViT.
+
+ For inference with Megatron-FSDP, refer to the `fully_shard` + `load_dcp_checkpoint` pattern in [train.py](./train.py) / [checkpoint.py](./checkpoint.py), which is described in [Megatron-FSDP DCP](#megatron-fsdp-dcp).