Skip to content

Commit c912f5b

Browse files
committed
Extend FSDP2 unit tests to include DCP checkpointing and parity tests.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 3aa6075 commit c912f5b

6 files changed

Lines changed: 155 additions & 18 deletions

File tree

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 140 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
#
55
# See LICENSE for license information.
66

7+
import argparse
78
import os
89
import sys
9-
import argparse
10+
import shutil
11+
from contextlib import nullcontext
12+
from copy import deepcopy
1013
from dataclasses import dataclass
14+
from pathlib import Path
1115

1216
import transformer_engine.pytorch as te
1317
import transformer_engine.common.recipe
14-
18+
from transformer_engine.pytorch import QuantizedTensor
1519
import torch
1620
import torch.distributed as dist
1721
from torch.distributed.checkpoint import save, load
@@ -27,11 +31,13 @@
2731
from torch.distributed import DeviceMesh
2832
from torch.distributed._composable.fsdp import fully_shard
2933
from torch.distributed.device_mesh import init_device_mesh
30-
from transformer_engine.pytorch import QuantizedTensor
31-
from contextlib import nullcontext
3234

3335
LOCAL_RANK = None
3436

37+
# Needed for `torch.distributed.checkpoint.{save,load}` because
38+
# multiple processes need to write to the same directory.
39+
SHARED_TMP_DIR = "/tmp/pytest-shared-tmp"
40+
3541

3642
@dataclass
3743
class AppState(Stateful):
@@ -63,7 +69,7 @@ def state_dict(self):
6369
# yet get_state_dict / _init_optim_state produce empty Tensors.
6470
# TransformerEngine uses empty Tensors for dummy Parameters.
6571
optimizer_state_dict["state"][fqn] = {}
66-
if fqn.endswith("._extra_state"):
72+
if fqn.endswith("_extra_state"):
6773
# Evict `_extra_state` quantization data from model checkpoint.
6874
model_state_dict.pop(fqn)
6975
return {
@@ -352,7 +358,9 @@ def test_fp8_fsdp2_allgather(model):
352358
# FP32 manual weight allgather
353359
fp32_allgathered_params = {}
354360
for name, param in model.named_parameters():
355-
assert isinstance(param, DTensor)
361+
assert isinstance(
362+
param, DTensor
363+
), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor."
356364
local_tensor = param._local_tensor
357365
device_mesh = param.device_mesh
358366
dist_group = (
@@ -471,7 +479,7 @@ def _train(args):
471479
optimizer = optim.Adam(model.parameters(), lr=1e-3)
472480

473481
"""
474-
Pre-Save Training
482+
FSDP2 Training
475483
"""
476484
for iteration in range(args.iter):
477485
# Zero the parameter gradients
@@ -499,6 +507,131 @@ def _train(args):
499507
if args.fp8_init:
500508
test_fp8_fsdp2_allgather(model)
501509

510+
"""
511+
DCP Checkpoint Testing
512+
"""
513+
# Compute the pre-save model loss on the last random input
514+
# with respect to the last random target.
515+
model.eval()
516+
with te.autocast(enabled=True, recipe=fp8_recipe):
517+
output = model(input_data)
518+
pre_save_loss = F.mse_loss(output, target)
519+
520+
# Save deep copy of the model and optimizer state before checkpointing.
521+
# NOTE(@cspades): deepcopy has issues with DTensors. Just clone().
522+
s1 = {}
523+
for key, val in model.state_dict().items():
524+
s1[key] = val.clone()
525+
optim_state_dict = optimizer.state_dict()
526+
o1 = {"state": {}}
527+
for idx, state in optim_state_dict["state"].items():
528+
o1_state = o1["state"].setdefault(idx, {})
529+
for key, val in state.items():
530+
o1_state[key] = val.clone()
531+
o1["param_groups"] = deepcopy(optim_state_dict["param_groups"])
532+
533+
# Write model to checkpoint.
534+
CKPT_DIR = (
535+
Path(SHARED_TMP_DIR)
536+
/ "run_fsdp2_model"
537+
/ f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
538+
)
539+
CKPT_DIR.mkdir(parents=True, exist_ok=True, mode=0o777)
540+
state_dict = {"app": AppState(model=model, optimizer=optimizer)}
541+
torch.distributed.checkpoint.save(state_dict, checkpoint_id=str(CKPT_DIR))
542+
543+
# Perform extra training steps (`args.iter` iterations) to change the weights such that
544+
# state parity tests will fail unless the checkpoint is loaded
545+
# without any errors or incongruities vs. the saved model state.
546+
model.train()
547+
for iteration in range(args.iter):
548+
with te.autocast(enabled=True, recipe=fp8_recipe):
549+
output = model(torch.randn(inp_shape).to(device))
550+
loss = F.mse_loss(output, torch.randn(out_shape).to(device))
551+
loss.backward()
552+
optimizer.step()
553+
554+
# Load the checkpoint.
555+
state_dict = {"app": AppState(model=model, optimizer=optimizer)}
556+
torch.distributed.checkpoint.load(state_dict=state_dict, checkpoint_id=str(CKPT_DIR))
557+
558+
# Validate checkpoint parity with pre-save state dictionaries.
559+
# Compare pre-save and post-load model state dictionaries.
560+
s2 = model.state_dict()
561+
nonempty_model_state = False
562+
for key in s1.keys() | s2.keys():
563+
if key.endswith("_extra_state"):
564+
# Don't parity test _extra_state. Shape can change after reset_parameters().
565+
continue
566+
v1 = s1.get(key, None)
567+
if isinstance(v1, DTensor):
568+
v1 = v1.to_local()
569+
v2 = s2.get(key, None)
570+
if isinstance(v2, DTensor):
571+
v2 = v2.to_local()
572+
assert (
573+
v1 is not None and v2 is not None
574+
), f"[{key} Not Found] Original Param: {v1} | Checkpoint Param: {v2}"
575+
assert (
576+
v1.shape == v2.shape
577+
), f"[Checkpoint Param {key} Shape Mismatch] {v1.shape} != {v2.shape}"
578+
assert torch.allclose(v1, v2), f"[Checkpoint Param {key} Value Mismatch] {v1} != {v2}"
579+
nonempty_model_state = True
580+
assert nonempty_model_state, "Model state should not be empty for evenly-sharded DTensors!"
581+
582+
# Compare pre-save and post-load optimizer state dictionaries.
583+
o2 = optimizer.state_dict()
584+
nonempty_optim_state = False
585+
for param_id in o1["state"].keys() | o2["state"].keys():
586+
param_state_1 = o1["state"].get(param_id, None)
587+
param_state_2 = o2["state"].get(param_id, None)
588+
assert param_state_1 is not None and param_state_2 is not None, (
589+
f"[{param_id} Not Found] Original Optim State: {param_state_1} | Checkpoint Optim"
590+
f" State: {param_state_2}"
591+
)
592+
for key in param_state_1.keys() | param_state_2.keys():
593+
v1 = param_state_1.get(key, None)
594+
if isinstance(v1, DTensor):
595+
v1 = v1.to_local()
596+
v2 = param_state_2.get(key, None)
597+
if isinstance(v2, DTensor):
598+
v2 = v2.to_local()
599+
assert v1 is not None and v2 is not None, (
600+
f"[{param_id} {key} Not Found] Original Optim State: {v1} | Checkpoint Optim State:"
601+
f" {v2}"
602+
)
603+
assert (
604+
v1.shape == v2.shape
605+
), f"[Optim State {param_id} {key} Shape Mismatch] {v1.shape} != {v2.shape}"
606+
assert torch.allclose(
607+
v1, v2
608+
), f"[Optim State {param_id} {key} Value Mismatch] {v1} != {v2}"
609+
nonempty_optim_state = True # Optimizer state depends on wgrad, verify this!
610+
assert nonempty_optim_state, "Optimizer state should not be empty for evenly-sharded DTensors!"
611+
assert len(o1["param_groups"]) == len(
612+
o2["param_groups"]
613+
), f"[Optim State Param Groups Length Mismatch] {o1['param_groups']} != {o2['param_groups']}"
614+
for i in range(len(o2["param_groups"])):
615+
for key in o1["param_groups"][i].keys():
616+
v1 = o1["param_groups"][i][key]
617+
v2 = o2["param_groups"][i][key]
618+
assert v1 == v2, f"[Optim State Param Group {i} {key} Value Mismatch] {v1} != {v2}"
619+
620+
# Validate post-load model loss.
621+
model.eval()
622+
with te.autocast(enabled=True, recipe=fp8_recipe):
623+
output = model(input_data)
624+
post_load_loss = F.mse_loss(output, target)
625+
# Allow a 1% relative tolerance, since `_extra_state` is evicted from the checkpoint and may differ after loading.
626+
assert torch.allclose(
627+
pre_save_loss, post_load_loss, rtol=1e-2
628+
), f"Pre-Save Loss: {pre_save_loss} != Post-Load Loss: {post_load_loss}"
629+
630+
# Clean up temporary checkpoint directory.
631+
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
632+
shutil.rmtree(CKPT_DIR)
633+
torch.distributed.barrier()
634+
502635
dist.destroy_process_group()
503636
return 0
504637

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
8383
[NUM_PROCS],
8484
# HSDP
8585
[2, NUM_PROCS // 2],
86-
# FSDP-TP
87-
[1, 2, NUM_PROCS // 2],
88-
# HSDP-TP
86+
# (H/F)SDP-TP
8987
[NUM_PROCS // 4, 2, 2],
9088
),
9189
)

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,8 @@ def set_device_mesh(
575575
weight_mesh : Optional[DeviceMesh]
576576
Not used for DotProductAttention as there are no quantized weights.
577577
"""
578-
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
578+
if weight_mesh is not None:
579+
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
579580
if tp_mesh is not None:
580581
# Validate TP DeviceMesh / Group. Must be consistent with tp_size.
581582
assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
from torch.distributed import DeviceMesh
1313
from torch.distributed.tensor import DTensor
14+
from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
1415

1516
import transformer_engine_torch as tex
1617

@@ -800,13 +801,17 @@ def make_grouped_weights(self, defer_init=False) -> None:
800801
weight_quantizers[0] is None or not weight_quantizers[0].internal
801802
), "Found internal quantizer with `single_grouped_parameter=True`."
802803
grouped_param = torch.nn.Parameter(grouped_weights)
803-
if isinstance(getattr(self, f"weight0", None), DTensor):
804+
if isinstance(getattr(self, "weight0", None), DTensor):
804805
# Convert to DTensor with properties equivalent to the original DTensor.
805-
dtensor_member_param = getattr(self, f"weight0")
806+
dtensor_member_param = getattr(self, "weight0")
807+
grouped_3d_placements = tuple(
808+
type(p)(p.dim + 1) if isinstance(p, (Shard, _StridedShard)) else p
809+
for p in dtensor_member_param.placements
810+
)
806811
grouped_param = _convert_param_to_dtensor_param(
807812
grouped_param,
808813
device_mesh=dtensor_member_param.device_mesh,
809-
placements=dtensor_member_param.placements,
814+
placements=grouped_3d_placements,
810815
# DTensor / DCP will view this as a TP-sharded 3-D Tensor.
811816
shape=(self.num_gemms, self.out_features, self.in_features),
812817
# Default Stride: (out*in, in, 1)
@@ -878,8 +883,6 @@ def set_device_mesh(
878883
self.set_tensor_parallel_group(tp_mesh.get_group())
879884

880885
# Construct TP-sharded DTensors.
881-
from torch.distributed.tensor.placement_types import Replicate, Shard
882-
883886
for weight in self.weight_names:
884887
param = getattr(self, weight)
885888
placements = (Replicate(),)

transformer_engine/pytorch/module/layernorm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ def set_device_mesh(
168168
Quantized DTensor parameters are currently not supported for FusibleOperation(s),
169169
and this mesh is not used.
170170
"""
171-
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
171+
if weight_mesh is not None:
172+
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
172173
if tp_mesh is not None:
173174
# Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
174175
# with DTensor parameters in TP layers to support DTensor operations.

transformer_engine/pytorch/module/rmsnorm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def set_device_mesh(
171171
Quantized DTensor parameters are currently not supported for FusibleOperation(s),
172172
and this mesh is not used.
173173
"""
174-
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
174+
if weight_mesh is not None:
175+
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
175176
if tp_mesh is not None:
176177
# Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
177178
# with DTensor parameters in TP layers to support DTensor operations.

0 commit comments

Comments
 (0)