Skip to content

Commit d89d170

Browse files
committed
Fix TE sharding, README.md updates to add inference instructions.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent a065450 commit d89d170

4 files changed

Lines changed: 83 additions & 16 deletions

File tree

recipes/vit/README.md

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,58 @@ The TIMM-derived model code for the ViT can be found in [`vit.py`](vit.py), and
4646

4747
Various configuration options common in computer vision modeling can be found in [config](./config/).
4848

49+
### Checkpointing
50+
51+
#### Megatron-FSDP DCP
52+
53+
To save Megatron-FSDP distributed checkpoints, refer to the following helper functions in [checkpoint.py](./checkpoint.py):
54+
55+
```python
56+
import torch
57+
58+
59+
def save_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
60+
"""Save a Torch DCP checkpoint of the model and optimizer to checkpoint_path.
61+
62+
Docs: https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
63+
"""
64+
# Save model and optimizer checkpoints.
65+
state_dict = {}
66+
if model is not None:
67+
state_dict["model"] = model.state_dict()
68+
if optimizer is not None:
69+
state_dict["optimizer"] = optimizer.state_dict()
70+
torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path)
71+
72+
73+
def load_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
74+
"""Load a Torch DCP checkpoint from checkpoint_path into model and optimizer.
75+
76+
Docs: https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
77+
"""
78+
# Load model and optimizer checkpoints.
79+
state_dict = {}
80+
if model is not None:
81+
state_dict["model"] = model.state_dict()
82+
if optimizer is not None:
83+
state_dict["optimizer"] = optimizer.state_dict()
84+
torch.distributed.checkpoint.load(state_dict, checkpoint_id=checkpoint_path)
85+
if model is not None:
86+
model.load_state_dict(state_dict["model"])
87+
if optimizer is not None:
88+
optimizer.load_state_dict(state_dict["optimizer"])
89+
```
90+
91+
which can be loaded directly into the `MegatronFSDP` model:
92+
93+
```python
94+
# Create a MegatronFSDP model and optimizer.
95+
model, optimizer = fully_shard(model, optimizer, ...)
96+
97+
# Load Megatron-FSDP DCP checkpoint into model and/or optimizer.
98+
load_dcp_checkpoint(CKPT_PATH, model=model, optimizer=optimizer)
99+
```
100+
49101
#### Checkpoint Conversion
50102

51103
To convert DCP checkpoints to non-distributed Torch checkpoints, and vice-versa, you can run the following command from `torch`:
@@ -72,8 +124,11 @@ python -m torch.distributed.checkpoint.format_utils dcp_to_torch step_75_loss_1.
72124

73125
or:
74126

75-
```
76-
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
127+
```python
128+
from torch.distributed.checkpoint.format_utils import (
129+
dcp_to_torch_save,
130+
torch_save_to_dcp,
131+
)
77132

78133
# Convert DCP model checkpoint to torch.save format.
79134
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_PATH)
@@ -82,27 +137,29 @@ dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_PATH)
82137
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_PATH, f"{CHECKPOINT_DIR}_new")
83138
```
84139

85-
_Note that `torch.save`-converted Megatron-FSDP distributed checkpoints (DCP) cannot be loaded directly into `MegatronFSDP` module classes, because Megatron-FSDP expects a deterministic unevenly sharded checkpoint when loading using DCP. To load a non-distributed checkpoint for training with Megatron-FSDP, simply load the checkpoint into the unsharded model before calling `fully_shard`!_
140+
#### Megatron-FSDP Checkpoint State Caveats
141+
142+
_Note that `torch.save`-converted distributed checkpoints (DCP) cannot be loaded directly into `MegatronFSDP` module classes, because Megatron-FSDP expects an unevenly-sharded DCP checkpoint with metadata not available in `torch.save` checkpoints that defines the distributed read and write sharding strategy for DCP load and save respectively. To load a non-distributed checkpoint for training with Megatron-FSDP, simply load the checkpoint into the unsharded model before calling `fully_shard` as an alternative to loading in a DCP checkpoint after `fully_shard`!_
86143

87144
```python
145+
from checkpoint import load_torch_checkpoint
146+
88147
# Initialize model.
89148
model = build_vit_model(cfg, device_mesh)
90149

91-
# Load model checkpoint. Remove the "module." prefix from the keys from Megatron-FSDP,
92-
# which is the main discrepancy between Megatron-FSDP and normal checkpoints.
93-
# Must load with weights_only=False if you have an optimizer state in your checkpoint.
94-
# NOTE(@cspades): `from checkpoint import load_torch_checkpoint`
95-
# -> load_torch_checkpoint(megatron_fsdp=True)
96-
model_checkpoint = {
97-
(k.strip("module.") if megatron_fsdp else k): v
98-
for k, v in torch.load(checkpoint_path, weights_only=False)["model"].items()
99-
}
100-
# Load with strict=False because the checkpoint may have TE-specific keys that are not
101-
# necessary for inference.
102-
model.load_state_dict(model_checkpoint, strict=False)
150+
# Load torch.save model checkpoint. If the checkpoint was converted
151+
# from a DCP checkpoint produced by Megatron-FSDP, set megatron_fsdp=True,
152+
# which simply strips the "module." prefix from the state dictionary.
153+
load_torch_checkpoint(CKPT_PATH, model, megatron_fsdp=True)
103154

104155
# Fully-shard.
105-
model = fully_shard_model(...)
156+
model, _ = fully_shard(model, ...)
106157
```
107158

108159
TODO(@cspades): For converting DCP directly to HuggingFace SafeTensors checkpoints, you can look into: https://pytorch.org/blog/huggingface-safetensors-support-in-pytorch-distributed-checkpointing/
160+
161+
### Inference
162+
163+
[infer.py](./infer.py) is an example inference script that loads in a non-distributed `torch.save` checkpoint into an un-sharded ViT.
164+
165+
For inference with Megatron-FSDP, refer to the `fully_shard` + `load_dcp_checkpoint` pattern in [train.py](./train.py) / [checkpoint.py](./checkpoint.py) and described in [Megatron-FSDP DCP](#megatron-fsdp-dcp).

recipes/vit/config/defaults.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ fsdp:
5151
fsdp_unit_modules:
5252
- vit.Block
5353
- vit.PatchEmbed
54+
- vit.AttentionPoolLatent
5455
- torch.nn.LayerNorm
5556
- torch.nn.Linear
5657
use_hybrid_fsdp: true

recipes/vit/config/vit_base_patch16_224.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ fsdp:
4949
fsdp_unit_modules:
5050
- vit.Block
5151
- vit.PatchEmbed
52+
- vit.AttentionPoolLatent
5253
- torch.nn.LayerNorm
5354
- torch.nn.Linear
5455
use_hybrid_fsdp: true

recipes/vit/config/vit_te_base_patch16_224.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ defaults:
66
model:
77
transformer_engine: true
88

9+
fsdp:
10+
fsdp_unit_modules:
11+
- transformer_engine.pytorch.TransformerLayer
12+
- vit.PatchEmbed
13+
- vit.AttentionPoolLatent
14+
- torch.nn.LayerNorm
15+
- torch.nn.Linear
16+
917
training:
1018
checkpoint:
1119
path: "./checkpoints/vit_te"

0 commit comments

Comments
 (0)