Skip to content

Commit 44e160b

Browse files
committed
Refactor everything, add checkpointing utilities, and add inference scripts and tests using Torch.save checkpoint loading.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 37c0562 commit 44e160b

12 files changed

Lines changed: 741 additions & 511 deletions

recipes/vit/README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,64 @@ which will train on a local tiny 5-class version of [ImageNet](https://image-net
4545
The TIMM-derived model code for the ViT can be found in [`vit.py`](vit.py), and data utilities for ImageNet can be found in [`imagenet_*.py`](imagenet_dataset.py).
4646

4747
Various configuration options common in computer vision modeling can be found in [config](./config/).
48+
49+
#### Checkpoint Conversion
50+
51+
To convert DCP checkpoints to non-distributed Torch checkpoints, and vice-versa, you can run the following command from `torch`:
52+
53+
```
54+
python -m torch.distributed.checkpoint.format_utils --help
55+
usage: format_utils.py [-h] {torch_to_dcp,dcp_to_torch} src dst
56+
57+
positional arguments:
58+
{torch_to_dcp,dcp_to_torch}
59+
Conversion mode
60+
src Path to the source model
61+
dst Path to the destination model
62+
63+
options:
64+
-h, --help show this help message and exit
65+
```
66+
67+
For example:
68+
69+
```
70+
python -m torch.distributed.checkpoint.format_utils dcp_to_torch step_75_loss_1.725 torch_ckpt_test.pt
71+
```
72+
73+
or:
74+
75+
```
76+
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
77+
78+
# Convert DCP model checkpoint to torch.save format.
79+
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_PATH)
80+
81+
# Convert torch.save model checkpoint back to DCP format.
82+
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_PATH, f"{CHECKPOINT_DIR}_new")
83+
```
84+
85+
_Note that `torch.save`-converted Megatron-FSDP distributed checkpoints (DCP) cannot be loaded directly into `MegatronFSDP` module classes, because Megatron-FSDP expects a deterministic unevenly sharded checkpoint when loading using DCP. To load a non-distributed checkpoint for training with Megatron-FSDP, simply load the checkpoint into the unsharded model before calling `fully_shard`!_
86+
87+
```python
88+
# Initialize model.
89+
model = build_vit_model(cfg, device_mesh)
90+
91+
# Load model checkpoint. Remove the "module." prefix from the keys from Megatron-FSDP,
92+
# which is the main discrepancy between Megatron-FSDP and normal checkpoints.
93+
# Must load with weights_only=False if you have an optimizer state in your checkpoint.
94+
# NOTE(@cspades): `from checkpoint import load_torch_checkpoint`
95+
# -> load_torch_checkpoint(megatron_fsdp=True)
96+
# NOTE: str.strip("module.") would remove any of the characters {m, o, d, u, l, e, .}
# from BOTH ends of the key (e.g. "module.pos_embed" -> "pos_emb"), so the prefix
# is removed explicitly instead.
model_checkpoint = {
    (k[len("module."):] if megatron_fsdp and k.startswith("module.") else k): v
    for k, v in torch.load(checkpoint_path, weights_only=False)["model"].items()
}
100+
# Load with strict=False because the checkpoint may have TE-specific keys that are not
101+
# necessary for inference.
102+
model.load_state_dict(model_checkpoint, strict=False)
103+
104+
# Fully-shard.
105+
model = fully_shard_model(...)
106+
```
107+
108+
TODO(@cspades): For converting DCP directly to HuggingFace SafeTensors checkpoints, you can look into: https://pytorch.org/blog/huggingface-safetensors-support-in-pytorch-distributed-checkpointing/

recipes/vit/checkpoint.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
from pathlib import Path
19+
20+
import torch
21+
import torch.distributed.checkpoint
22+
23+
24+
_logger = logging.getLogger(__name__)
25+
26+
27+
def load_torch_checkpoint(model, checkpoint_path, megatron_fsdp=False):
    """Load a Torch checkpoint from checkpoint_path into an unsharded model.

    Used for converting existing TIMM or Torch checkpoints into a freshly initialized
    model prior to sharding with Megatron-FSDP.

    If the checkpoint was created from a Megatron-FSDP DCP checkpoint, then setting
    megatron_fsdp=True is required and strips a "module." prefix from the keys.

    Args:
        model: Unsharded module exposing load_state_dict().
        checkpoint_path: Path to a torch.save checkpoint containing a "model" entry
            that maps parameter names to values.
        megatron_fsdp: If True, remove one leading "module." from every key.

    Docs: https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html
    """
    prefix = "module."
    # Must load with weights_only=False if you have an optimizer state in your checkpoint.
    # NOTE: weights_only=False unpickles arbitrary objects, so only load trusted files.
    raw_state_dict = torch.load(checkpoint_path, weights_only=False)["model"]
    # Remove the "module." prefix from the keys from Megatron-FSDP, which is the main
    # discrepancy between Megatron-FSDP and normal checkpoints. The prefix is sliced off
    # explicitly: str.strip("module.") treats its argument as a character set and would
    # also eat matching characters at both ends of the key (e.g. "pos_embed" -> "pos_emb").
    model_checkpoint = {
        (k[len(prefix):] if megatron_fsdp and k.startswith(prefix) else k): v
        for k, v in raw_state_dict.items()
    }
    # Warn about Megatron-FSDP checkpoints. An empty state dict has no key to inspect.
    first_key = next(iter(model_checkpoint), None)
    if first_key is not None and first_key.startswith(prefix) and not megatron_fsdp:
        _logger.warning(
            f"Checkpoint state dictionary keys ({first_key}) may be prefixed "
            "with 'module.' if converted from a Megatron-FSDP DCP checkpoint. "
            "Set megatron_fsdp=True to automatically strip the prefix."
        )
    # Load with strict=False because the checkpoint may have
    # TE-specific keys that are not necessary for inference.
    model.load_state_dict(model_checkpoint, strict=False)
55+
56+
57+
def load_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
    """Load a Torch DCP checkpoint from checkpoint_path into model and optimizer.

    Either argument may be omitted (None), in which case that component is neither
    requested from the checkpoint nor updated. If both are None, nothing is loaded.

    Args:
        checkpoint_path: Checkpoint identifier passed to DCP as checkpoint_id.
        model: Optional object exposing state_dict() / load_state_dict().
        optimizer: Optional object exposing state_dict() / load_state_dict().

    Docs: https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
    """
    # Collect the state dictionaries that DCP should populate.
    state_dict = {}
    if model is not None:
        state_dict["model"] = model.state_dict()
    if optimizer is not None:
        state_dict["optimizer"] = optimizer.state_dict()
    if not state_dict:
        # Nothing was requested, so there is nothing to read from the checkpoint.
        return
    torch.distributed.checkpoint.load(state_dict, checkpoint_id=checkpoint_path)
    # Only push state back into the components that were actually provided.
    if model is not None:
        model.load_state_dict(state_dict["model"])
    if optimizer is not None:
        optimizer.load_state_dict(state_dict["optimizer"])
71+
72+
73+
def load_auto_resume_checkpoint(cfg, model, optimizer):
    """Resume training from a previously saved checkpoint, if one exists.

    Checkpoint directories are expected to be named: step_<step_idx>_loss_<loss_value>
    When cfg.training.checkpoint.resume_from_metric is set, candidates are ranked by
    loss_value ('+' prefers the highest value, anything else the lowest), with the
    modification time as tie-breaker. When it is unset, only the modification time
    decides and the most recently modified directory is chosen.

    Args:
        cfg: Hydra config.
        model: Model to load checkpoints into.
        optimizer: Optimizer to load checkpoints into.

    Returns:
        The step index parsed from the selected checkpoint directory, or 0 when no
        checkpoint path is configured, the path does not exist, or it holds no
        sub-directories.
    """
    resume_step = 0
    ckpt_cfg = cfg.training.checkpoint
    if not ckpt_cfg.path or not Path(ckpt_cfg.path).exists():
        return resume_step

    # The checkpoint root should ONLY contain Torch DCP checkpoint sub-directories.
    candidates = [entry.absolute() for entry in Path(ckpt_cfg.path).iterdir() if entry.is_dir()]
    if not candidates:
        return resume_step

    # Sign applied to the metric so that max() picks the preferred direction.
    sign = 1 if ckpt_cfg.resume_from_metric == "+" else -1

    def _priority(ckpt_dir):
        """Rank by (signed metric or 0, modification time)."""
        if ckpt_cfg.resume_from_metric:
            metric = sign * float(ckpt_dir.name.split("_")[3])
        else:
            metric = 0
        return (metric, ckpt_dir.stat().st_mtime)

    latest_subdir = max(candidates, key=_priority)
    # The second underscore-separated token of the directory name is the step index.
    resume_step = int(latest_subdir.name.split("_")[1])
    # Restore model and optimizer state from the selected checkpoint.
    load_dcp_checkpoint(latest_subdir, model, optimizer)
    if torch.distributed.get_rank() == 0:
        _logger.info(f"Loaded latest model and optimizer checkpoints from: {latest_subdir}")

    return resume_step
116+
117+
118+
def save_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
    """Write the model and optimizer state to checkpoint_path as a Torch DCP checkpoint.

    Components passed as None are left out of the saved state dictionary.

    Docs: https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
    """
    # Gather the state of every component that was provided, model first.
    components = (("model", model), ("optimizer", optimizer))
    state_dict = {
        label: component.state_dict()
        for label, component in components
        if component is not None
    }
    torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path)
130+
131+
132+
def save_auto_resumable_checkpoint(cfg, model, optimizer, step_idx, loss_value):
    """Write a checkpoint of the model and optimizer that auto-resume can pick up.

    The checkpoint is stored under cfg.training.checkpoint.path in a sub-directory
    named step_<step_idx>_loss_<loss_value>, with the loss formatted to three decimals.
    Nothing is written when no checkpoint path is configured.

    Args:
        cfg: Hydra config.
        model: Model to save checkpoints of.
        optimizer: Optimizer to save checkpoints of.
        step_idx: Step index to save checkpoint at.
        loss_value: Loss value to save checkpoint at.
    """
    ckpt_root = cfg.training.checkpoint.path
    if not ckpt_root:
        return

    # Create the checkpoint sub-directory and write the DCP checkpoint into it.
    ckpt_dir = Path(ckpt_root) / f"step_{step_idx}_loss_{loss_value:.3f}"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    save_dcp_checkpoint(ckpt_dir, model, optimizer)

    # Relax checkpoint permissions, which may be helpful when saving checkpoints
    # in a container owned by root.
    mode = 0o777
    for dirpath, _, filenames in os.walk(ckpt_dir):
        # Directory first, then each file it contains.
        targets = [dirpath, *(Path(dirpath) / filename for filename in filenames)]
        for target in targets:
            os.chmod(target, mode)

    if torch.distributed.get_rank() == 0:
        _logger.info(f"Saved validated checkpoint to: {ckpt_dir}")

recipes/vit/config/defaults.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ training:
6666
path: null
6767
resume_from_metric: null
6868

69+
inference:
70+
checkpoint:
71+
path: null
72+
6973
dataset:
7074
num_classes: 100000
7175
num_workers: 0

recipes/vit/config/vit_base_patch16_224.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ training:
6464
path: "./checkpoints/vit"
6565
resume_from_metric: "-" # + = Highest Metric (Score), - = Lowest Metric (Loss)
6666

67+
inference:
68+
checkpoint:
69+
path: "./checkpoints/vit/torch_ckpt_test.pt"
70+
6771
dataset:
6872
num_classes: 100000
6973
num_workers: 4

recipes/vit/config/vit_te_base_patch16_224.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ training:
1010
checkpoint:
1111
path: "./checkpoints/vit_te"
1212
resume_from_metric: "-" # + = Highest Metric (Score), - = Lowest Metric (Loss)
13+
14+
inference:
15+
checkpoint:
16+
path: "./checkpoints/vit_te/torch_ckpt_test.pt"

recipes/vit/distributed.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
from contextlib import contextmanager
18+
19+
import torch
20+
21+
22+
@contextmanager
def initialize_distributed(cfg):
    """
    Setup the DeviceMesh for distributed training.

    The process group is initialized on entry and always destroyed on exit, including
    when setup fails or the body of the `with` statement raises.

    Args:
        cfg: Hydra config.

    Yields:
        device_mesh: The DeviceMesh.

    Raises:
        ValueError: If the parallelism sizes are invalid.
    """
    # Initialize distributed training environment.
    torch.distributed.init_process_group()

    # Everything after initialization runs under try/finally so that the process group
    # is torn down even if validation, mesh construction, or the caller's code raises.
    try:
        # Associate all future device operations in the current process
        # with a uniquely-indexed local device, e.g. "cuda:0" on Rank 0.
        local_rank = int(os.getenv("LOCAL_RANK", torch.distributed.get_rank()))
        torch.cuda.set_device(local_rank)

        # Initialize DeviceMesh. Validate parallelism sizes.
        # TODO(@cspades): Will add TE-backed context parallelism (CP) in the future, just need to
        # modify the ViT model to shard the sequence dimension after tokenization. For now, we
        # setup the CP dimension for demonstrating how to use DeviceMesh and CP with Megatron-FSDP.
        if cfg.distributed.dp_inter * cfg.distributed.dp_shard * cfg.distributed.cp != torch.distributed.get_world_size():
            raise ValueError(
                f"Invalid parallelism sizes: dp_inter({cfg.distributed.dp_inter}) * dp_shard({cfg.distributed.dp_shard}) * cp({cfg.distributed.cp}) * tp(1) != world_size({torch.distributed.get_world_size()})"
            )
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            mesh_shape=(
                cfg.distributed.dp_inter,
                cfg.distributed.dp_shard,
                cfg.distributed.cp,
                1,  # Needed to use TransformerEngine layers with Megatron-FSDP. "TP is always 1."
            ),
            mesh_dim_names=("dp_inter", "dp_shard", "cp", "tp"),
        )

        # Sub-meshes (possibly) required for Megatron-FSDP.
        # WARNING: These have a tendency to be deleted by Torch. Save references
        # or pass them to all classes or functions that use them.
        # DP: Only relevant when using HSDP, where we need the flattened DP group for data parallelism. (Otherwise, just pass dp_shard.)
        device_mesh[("dp_inter", "dp_shard")]._flatten("dp")
        # DP-Shard-CP: Only required if using CP. Otherwise, just pass dp_shard to FSDP.
        device_mesh[("dp_shard", "cp")]._flatten("dp_cp_shard")
        # HSDP (DP-CP): Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group to Megatron-FSDP.
        device_mesh[("dp_inter", "dp_shard", "cp")]._flatten("hsdp")

        # Yield DeviceMesh.
        yield device_mesh
    finally:
        # Destroy process group.
        torch.distributed.destroy_process_group()

recipes/vit/imagenet_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,12 @@ def __init__(
205205
if isinstance(class_map, str):
206206
class_to_idx = load_class_map(class_map)
207207
elif isinstance(class_map, dict):
208-
assert dict, "Class-to-Index mapping dict must be non-empty."
208+
assert class_map, "Class-to-Index mapping dict must be non-empty."
209209
class_to_idx = class_map
210210
if isinstance(label_map, str):
211211
image_to_label = load_image_labels(label_map)
212212
elif isinstance(label_map, dict):
213-
assert dict, "Image-to-Label mapping dict must be non-empty."
213+
assert label_map, "Image-to-Label mapping dict must be non-empty."
214214
image_to_label = label_map
215215
self.samples, self.class_to_idx = find_images_and_targets(
216216
root,

0 commit comments

Comments
 (0)