Skip to content

Commit 0fffd7b

Browse files
committed
Minor function and config changes.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent a1f46b1 commit 0fffd7b

10 files changed

Lines changed: 83 additions & 50 deletions

File tree

recipes/vit/checkpoint.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_logger = logging.getLogger(__name__)
2525

2626

27-
def load_torch_checkpoint(model, checkpoint_path, megatron_fsdp=False):
27+
def load_torch_checkpoint(checkpoint_path, model, megatron_fsdp=False):
2828
"""Load a Torch checkpoint from checkpoint_path into an unsharded model.
2929
Used for converting existing TIMM or Torch checkpoints into a freshly initialized
3030
model prior to sharding with Megatron-FSDP.
@@ -34,19 +34,18 @@ def load_torch_checkpoint(model, checkpoint_path, megatron_fsdp=False):
3434
3535
Docs: https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html
3636
"""
37-
# Load model checkpoint. Remove the "module." prefix from the keys from Megatron-FSDP,
38-
# which is the main discrepancy between Megatron-FSDP and normal checkpoints.
39-
# Must load with weights_only=False if you have an optimizer state in your checkpoint.
40-
model_checkpoint = {
41-
(k.strip("module.") if megatron_fsdp else k): v
42-
for k, v in torch.load(checkpoint_path, weights_only=False)["model"].items()
43-
}
37+
# Load model checkpoint. Must load with weights_only=False
38+
# if you have an optimizer state in your checkpoint.
39+
checkpoint = torch.load(checkpoint_path, weights_only=False)
40+
# Remove the "module." prefix from the keys of checkpoints
41+
# derived from Megatron-FSDP.
42+
model_checkpoint = {(k.removeprefix("module.") if megatron_fsdp else k): v for k, v in checkpoint["model"].items()}
4443
# Warn about Megatron-FSDP checkpoints.
4544
first_key = next(iter(model_checkpoint))
4645
if first_key.startswith("module.") and not megatron_fsdp:
4746
_logger.warning(
4847
f"Checkpoint state dictionary keys ({first_key}) may be prefixed "
49-
"with 'modele.' if converted from a Megatron-FSDP DCP checkpoint."
48+
"with 'module.' if converted from a Megatron-FSDP DCP checkpoint."
5049
"Set megatron_fsdp=True to automatically strip the prefix."
5150
)
5251
# Load with strict=False because the checkpoint may have
@@ -66,8 +65,10 @@ def load_dcp_checkpoint(checkpoint_path, model=None, optimizer=None):
6665
if optimizer is not None:
6766
state_dict["optimizer"] = optimizer.state_dict()
6867
torch.distributed.checkpoint.load(state_dict, checkpoint_id=checkpoint_path)
69-
model.load_state_dict(state_dict["model"])
70-
optimizer.load_state_dict(state_dict["optimizer"])
68+
if model is not None:
69+
model.load_state_dict(state_dict["model"])
70+
if optimizer is not None:
71+
optimizer.load_state_dict(state_dict["optimizer"])
7172

7273

7374
def load_auto_resume_checkpoint(cfg, model, optimizer):

recipes/vit/config/defaults.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ optimizer:
4141
weight_decay: 0.01
4242

4343
distributed:
44-
dp_inter: 1
44+
dp_outer: 1
4545
dp_shard: 1
4646
cp: 1
4747

@@ -69,6 +69,8 @@ training:
6969
inference:
7070
checkpoint:
7171
path: null
72+
format: null
73+
megatron_fsdp: null
7274

7375
dataset:
7476
num_classes: 100000

recipes/vit/config/vit_base_patch16_224.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ model:
3939
channels_last: false
4040

4141
distributed:
42-
dp_inter: 1
42+
dp_outer: 1
4343
dp_shard: 1
4444
cp: 1
4545

@@ -66,7 +66,12 @@ training:
6666

6767
inference:
6868
checkpoint:
69-
path: "./checkpoints/vit/torch_ckpt_test.pt"
69+
path: null
70+
# Load a DCP->Torch converted checkpoint for inference without Megatron-FSDP.
71+
# Otherwise, set this to "torch_dcp" if using Megatron-FSDP for inference.
72+
# If the checkpoint was not trained with Megatron-FSDP, then set megatron_fsdp to false.
73+
format: "torch"
74+
megatron_fsdp: true
7075

7176
dataset:
7277
num_classes: 100000

recipes/vit/config/vit_te_base_patch16_224.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ training:
1313

1414
inference:
1515
checkpoint:
16-
path: "./checkpoints/vit_te/torch_ckpt_test.pt"
16+
path: null
17+
# Load a DCP->Torch converted checkpoint for inference without Megatron-FSDP.
18+
# Otherwise, set this to "torch_dcp" if using Megatron-FSDP for inference.
19+
# If the checkpoint was not trained with Megatron-FSDP, then set megatron_fsdp to false.
20+
format: "torch"
21+
megatron_fsdp: true

recipes/vit/distributed.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,20 @@
2020

2121

2222
@contextmanager
23-
def initialize_distributed(cfg):
23+
def initialize_distributed(
24+
dp_outer: int = 1,
25+
dp_shard: int = 1,
26+
cp: int = 1,
27+
tp: int = 1,
28+
):
2429
"""
2530
Setup the DeviceMesh for distributed training.
2631
2732
Args:
28-
cfg: Hydra config.
33+
dp_outer: The size of the data parallelism outer dimension.
34+
dp_shard: The size of the data parallelism shard dimension.
35+
cp: The size of the context parallelism dimension.
36+
tp: The size of the tensor parallelism dimension.
2937
3038
Yields:
3139
device_mesh: The DeviceMesh.
@@ -45,30 +53,30 @@ def initialize_distributed(cfg):
4553
# TODO(@cspades): Will add TE-backed context parallelism (CP) in the future, just need to
4654
# modify the ViT model to shard the sequence dimension after tokenization. For now, we
4755
# setup the CP dimension for demonstrating how to use DeviceMesh and CP with Megatron-FSDP.
48-
if cfg.distributed.dp_inter * cfg.distributed.dp_shard * cfg.distributed.cp != torch.distributed.get_world_size():
56+
if dp_outer * dp_shard * cp != torch.distributed.get_world_size():
4957
raise ValueError(
50-
f"Invalid parallelism sizes: dp_inter({cfg.distributed.dp_inter}) * dp_shard({cfg.distributed.dp_shard}) * cp({cfg.distributed.cp}) * tp(1) != world_size({torch.distributed.get_world_size()})"
58+
f"Invalid parallelism sizes: dp_outer({dp_outer}) * dp_shard({dp_shard}) * cp({cp}) * tp({tp}) != world_size({torch.distributed.get_world_size()})"
5159
)
5260
device_mesh = torch.distributed.device_mesh.init_device_mesh(
5361
"cuda",
5462
mesh_shape=(
55-
cfg.distributed.dp_inter,
56-
cfg.distributed.dp_shard,
57-
cfg.distributed.cp,
58-
1, # Needed to use TransformerEngine layers with Megatron-FSDP. "TP is always 1."
63+
dp_outer,
64+
dp_shard,
65+
cp,
66+
tp, # Needed to use TransformerEngine layers with Megatron-FSDP.
5967
),
60-
mesh_dim_names=("dp_inter", "dp_shard", "cp", "tp"),
68+
mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
6169
)
6270

6371
# Sub-meshes (possibly) required for Megatron-FSDP.
6472
# WARNING: These have a tendency to be deleted by Torch. Save references
6573
# or pass them to all classes or functions that use them.
6674
# DP: Only relevant when using HSDP, where we need the flattened DP group for data parallelism. (Otherwise, just pass dp_shard.)
67-
device_mesh[("dp_inter", "dp_shard")]._flatten("dp")
75+
device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
6876
# DP-Shard-CP: Only required if using CP. Otherwise, just pass dp_shard to FSDP.
6977
device_mesh[("dp_shard", "cp")]._flatten("dp_cp_shard")
7078
# HSDP (DP-CP): Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group to Megatron-FSDP.
71-
device_mesh[("dp_inter", "dp_shard", "cp")]._flatten("hsdp")
79+
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
7280

7381
# Yield DeviceMesh.
7482
yield device_mesh

recipes/vit/infer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ def main(cfg) -> None:
3131
"""
3232
Inference script for ViT. Non-distributed inference.
3333
"""
34-
with initialize_distributed(cfg) as device_mesh:
34+
with initialize_distributed(**cfg.distributed) as device_mesh:
3535
# Init ViT.
3636
model = build_vit_model(cfg, device_mesh).cuda()
3737

38-
# Load model checkpoint trained using Megatron-FSDP.
39-
load_torch_checkpoint(model, cfg.inference.checkpoint.path, megatron_fsdp=True)
38+
# Load torch.save (non-distributed) model checkpoint trained using (or not using) Megatron-FSDP.
39+
load_torch_checkpoint(
40+
cfg.inference.checkpoint.path, model, megatron_fsdp=cfg.inference.checkpoint.megatron_fsdp
41+
)
4042
logger.info(f"Model: {model}")
4143

4244
# Mock input.

recipes/vit/test_infer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,15 @@ def test_infer(monkeypatch, tmp_path, config_name):
4545
config_name=config_name,
4646
overrides=[
4747
f"++inference.checkpoint.path={test_ckpt_path}",
48+
# Using a torch.save mock checkpoint for inference.
49+
"++inference.checkpoint.format=torch",
50+
# Using a non-Megatron-FSDP mock checkpoint for inference.
51+
"++inference.checkpoint.megatron_fsdp=false",
4852
],
4953
)
5054

5155
# Write a test checkpoint.
52-
with initialize_distributed(vit_config) as device_mesh:
56+
with initialize_distributed(**vit_config.distributed) as device_mesh:
5357
# Init ViT.
5458
model = build_vit_model(vit_config, device_mesh).cuda()
5559
# Write checkpoint.

recipes/vit/test_train.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,27 @@ def test_train(monkeypatch, tmp_path, config_name, init_model_with_meta_device):
3737

3838
# Initialize training config.
3939
recipe_dir = Path(__file__).parent
40+
training_ckpt_path = Path(tmp_path) / "test_train_checkpoints"
4041
with initialize_config_dir(config_dir=str(recipe_dir / "config"), version_base="1.2"):
4142
vit_config = compose(
4243
config_name=config_name,
4344
overrides=[
44-
"++training.steps=10",
45-
"++training.val_interval=10",
45+
"++training.steps=5",
46+
"++training.val_interval=5",
4647
"++training.log_interval=1",
47-
f"++training.checkpoint.path={Path(tmp_path) / 'ckpt'}",
48+
f"++training.checkpoint.path={training_ckpt_path}",
4849
"++profiling.torch_memory_profile=false",
4950
"++profiling.wandb=false",
5051
f"++fsdp.init_model_with_meta_device={init_model_with_meta_device}",
5152
],
5253
)
53-
vit_resume_config = deepcopy(vit_config)
54-
vit_resume_config.training.steps = 10
5554

5655
main(vit_config)
5756

5857
# Verify checkpoints were created.
59-
assert sum(1 for item in (Path(tmp_path) / "ckpt").iterdir() if item.is_dir()) == 1, (
60-
"Expected 1 checkpoint with 10 training steps and validation interval of 10."
58+
assert sum(1 for item in training_ckpt_path.iterdir() if item.is_dir()) == 1, (
59+
"Expected 1 checkpoint with 5 training steps and validation interval of 5."
6160
)
6261

63-
# Auto-resume training from checkpoint. For this test, we auto-resume from the best checkpoint,
64-
# so depending on what the best checkpoint is, we may have more than 5 checkpoints.
65-
main(vit_resume_config)
62+
# Auto-resume training from checkpoint. For this test, we auto-resume from the best checkpoint.
63+
main(vit_config)

recipes/vit/train.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main(cfg) -> None:
4545
"""Train a ViT model on ImageNet using Megatron-FSDP and TransformerEngine (TE)."""
4646

4747
# Initialize distributed environment.
48-
with initialize_distributed(cfg) as device_mesh:
48+
with initialize_distributed(**cfg.distributed) as device_mesh:
4949
"""
5050
Profiling
5151
"""
@@ -92,7 +92,7 @@ def main(cfg) -> None:
9292
# Always required to use Megatron-FSDP. What we shard on.
9393
dp_shard_dim="dp_cp_shard",
9494
# Required if using HSDP. The second / intermediate set of data-parallel process groups.
95-
dp_inter_dim="dp_inter",
95+
dp_inter_dim="dp_outer",
9696
# Required if using TP, either from TransformerEngine (TP=1) / Megatron or DTensor-based TP.
9797
tp_dim="tp",
9898
# Required if using HSDP. Created by flattening everything we shard on, e.g. DP-CP.
@@ -142,9 +142,9 @@ def main(cfg) -> None:
142142
sampler=train_sampler,
143143
num_workers=cfg.dataset.num_workers,
144144
# IMPORTANT: persistent_workers=True is required for Megatron-FSDP and
145-
# Torch DCP, because CUDA/NCCL and Dataloader kill each others workers!
145+
# Torch DCP, because CUDA/NCCL and DataLoader kill each other's workers!
146146
# Alternatively, you can set num_workers=0.
147-
persistent_workers=True,
147+
persistent_workers=(cfg.dataset.num_workers > 0),
148148
)
149149
if torch.distributed.get_rank() == 0:
150150
_logger.info(f"Training Dataset Size: {len(imagenet_train_ds)}")
@@ -171,9 +171,9 @@ def main(cfg) -> None:
171171
sampler=val_sampler,
172172
num_workers=cfg.dataset.num_workers,
173173
# IMPORTANT: persistent_workers=True is required for Megatron-FSDP and
174-
# Torch DCP, because CUDA/NCCL and Dataloader kill each others workers!
174+
# Torch DCP, because CUDA/NCCL and DataLoader kill each other's workers!
175175
# Alternatively, you can set num_workers=0.
176-
persistent_workers=True,
176+
persistent_workers=(cfg.dataset.num_workers > 0),
177177
)
178178
if torch.distributed.get_rank() == 0:
179179
_logger.info(f"Validation Dataset Size: {len(imagenet_val_ds)}")
@@ -211,6 +211,7 @@ def main(cfg) -> None:
211211

212212
# Set training mode.
213213
model.train()
214+
optimizer.zero_grad()
214215

215216
# Match model input shape.
216217
if cfg.model.channels_last:

recipes/vit/vit.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,14 @@
6464
import torch
6565
import torch.nn as nn
6666
import torch.nn.functional as F
67-
from transformer_engine.pytorch import TransformerLayer
67+
68+
69+
try:
70+
from transformer_engine.pytorch import TransformerLayer
71+
72+
_TE_INSTALLED = True
73+
except ImportError:
74+
_TE_INSTALLED = False
6875

6976

7077
def build_vit_model(cfg, device_mesh=None, meta_init=False):
@@ -85,7 +92,7 @@ def build_vit_model(cfg, device_mesh=None, meta_init=False):
8592
vit_kwargs = dict(cfg.model.vit)
8693
if meta_init:
8794
vit_kwargs["weight_init"] = None
88-
if cfg.model.transformer_engine:
95+
if cfg.model.transformer_engine and _TE_INSTALLED:
8996
assert device_mesh is not None, "[build_model] device_mesh is required when using TransformerEngine."
9097
vit_kwargs["block_fn"] = TransformerLayer
9198
vit_kwargs["micro_batch_size"] = cfg.dataset.train.batch_size
@@ -1385,7 +1392,7 @@ def __init__(
13851392
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device="cpu")] # stochastic depth decay rule
13861393

13871394
self.block_fn = block_fn
1388-
if block_fn == TransformerLayer:
1395+
if _TE_INSTALLED and block_fn == TransformerLayer:
13891396
self.blocks = nn.Sequential(
13901397
*[
13911398
TransformerLayer(
@@ -1464,7 +1471,7 @@ def rescale(param, _layer_id):
14641471
param.div_(math.sqrt(2.0 * _layer_id))
14651472

14661473
for layer_id, layer in enumerate(self.blocks):
1467-
if self.block_fn == TransformerLayer:
1474+
if _TE_INSTALLED and self.block_fn == TransformerLayer:
14681475
rescale(layer.self_attention.proj.weight.data, layer_id + 1)
14691476
rescale(layer.layernorm_mlp.fc2_weight.data, layer_id + 1)
14701477
else:

0 commit comments

Comments
 (0)