
Commit 09e40bb

feat(pt): add FSDP & ZeRO1 (Zero Redundancy Optimizer) support (#5222)
## Summary by CodeRabbit

* **New Features**
  * Configurable ZeRO/FSDP memory-optimization stages (`zero_stage` 0–3) with FSDP/optimizer-state sharding, synchronized parameter broadcasts, and consistent rank-0 checkpoint writes.
  * Runtime logs indicating the optimizer/sharding mode, and automatic distributed initialization when enabled.
* **Bug Fixes**
  * Validation and enforced constraints for incompatible combinations (single-task training required for `zero_stage > 0`; LKF and certain bias-change flows disallowed).
* **Documentation**
  * Expanded PyTorch parallel-training docs with stage guidance, an example config, constraints, and launch notes.

Signed-off-by: OutisLi <137472077+OutisLi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e23443f commit 09e40bb
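
The release notes mention stage guidance and an example config. As a hedged illustration (only `zero_stage` and its 0–3 range come from this commit; the surrounding keys and the exact launch command are assumptions), enabling optimizer-state sharding might look like:

```python
# Hypothetical "training" section of a DeePMD input file, written here as a
# Python dict for illustration. Only "zero_stage" is introduced by this commit;
# the other keys are placeholders.
training = {
    "zero_stage": 1,  # 0: plain DDP, 1: DDP + ZeRO-1 optimizer-state sharding,
                      # 2: FSDP2 SHARD_GRAD_OP, 3: FSDP2 FULL_SHARD
    "numb_steps": 100000,
    # ...
}

# zero_stage > 0 requires a distributed launch (the commit raises a ValueError
# otherwise), e.g. something along the lines of:
#   torchrun --nproc_per_node=4 ... dp --pt train input.json
```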

3 files changed: 259 additions & 55 deletions

File: deepmd/pt/train/training.py (189 additions & 55 deletions)
```diff
@@ -81,6 +81,18 @@
 import torch._dynamo

 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
+from torch.distributed.fsdp import (
+    fully_shard,
+)
+from torch.distributed.optim import (
+    ZeroRedundancyOptimizer,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import (
     DataLoader,
```
```diff
@@ -131,14 +143,9 @@ def __init__(
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
         )
-        self.rank = (
-            dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        )
-        self.world_size = (
-            dist.get_world_size()
-            if dist.is_available() and dist.is_initialized()
-            else 1
-        )
+        self.is_distributed = dist.is_available() and dist.is_initialized()
+        self.rank = dist.get_rank() if self.is_distributed else 0
+        self.world_size = dist.get_world_size() if self.is_distributed else 1
         self.num_model = len(self.model_keys)

         # Iteration config
```
```diff
@@ -154,6 +161,19 @@ def __init__(
         self.change_bias_after_training = training_params.get(
             "change_bias_after_training", False
         )
+        self.zero_stage = int(training_params.get("zero_stage", 0))
+        if self.zero_stage not in (0, 1, 2, 3):
+            raise ValueError(
+                f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
+            )
+        if self.zero_stage > 0 and not self.is_distributed:
+            raise ValueError(
+                "training.zero_stage requires distributed launch via torchrun."
+            )
+        if self.zero_stage > 0 and self.change_bias_after_training:
+            raise ValueError(
+                "training.zero_stage does not support change_bias_after_training."
+            )
         self.lcurve_should_print_header = True

     def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
```
```diff
@@ -300,6 +320,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             )
         else:
             self.opt_type, self.opt_param = get_opt_param(training_params)
+        if self.zero_stage > 0 and self.multi_task:
+            raise ValueError(
+                "training.zero_stage is currently only supported in single-task training."
+            )
+        if self.zero_stage > 0 and self.opt_type == "LKF":
+            raise ValueError("training.zero_stage does not support LKF optimizer.")

         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
```
```diff
@@ -690,15 +716,25 @@ def single_model_finetune(
             data_stat_protect=_data_stat_protect[0],
         )

-        if dist.is_available() and dist.is_initialized():
+        if self.is_distributed:
             torch.cuda.set_device(LOCAL_RANK)
-            # DDP will guarantee the model parameters are identical across all processes
-            self.wrapper = DDP(
-                self.wrapper,
-                device_ids=[LOCAL_RANK],
-                find_unused_parameters=True,
-                output_device=LOCAL_RANK,
-            )
+            if self.zero_stage >= 2:
+                # FSDP2 does NOT broadcast params (unlike DDP constructor).
+                # Ensure all ranks share identical weights before sharding.
+                for p in self.wrapper.parameters():
+                    dist.broadcast(p.data, src=0)
+                for b in self.wrapper.buffers():
+                    dist.broadcast(b.data, src=0)
+                reshard = self.zero_stage >= 3
+                self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard)
+            else:
+                # zero_stage=0 or 1: standard DDP (ZeRO-1 will wrap the optimizer)
+                self.wrapper = DDP(
+                    self.wrapper,
+                    device_ids=[LOCAL_RANK],
+                    find_unused_parameters=True,
+                    output_device=LOCAL_RANK,
+                )

         # TODO add lr warmups for multitask
         # author: iProzd
```
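
The broadcast-before-shard step above exists because FSDP2's `fully_shard`, unlike the DDP constructor, does not synchronize initial weights across ranks. A minimal sketch of the same wrap order on a toy module (assumes a `torchrun` launch with NCCL; the model and sizes are illustrative, not from this commit):

```python
# Minimal FSDP2 wrap-order sketch (run under torchrun so dist initializes).
# reshard_after_forward mirrors the commit: True for stage 3 (FULL_SHARD),
# False for stage 2 (SHARD_GRAD_OP).
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(64, 64, device="cuda")  # toy stand-in for the wrapper

# fully_shard does not replicate rank-0 weights, so broadcast them first.
for p in model.parameters():
    dist.broadcast(p.data, src=0)
for b in model.buffers():
    dist.broadcast(b.data, src=0)

model = fully_shard(model, reshard_after_forward=True)  # stage-3 behavior
```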
```diff
@@ -714,20 +750,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
             if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.Adam,
                     lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
             else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.AdamW,
                     lr=self.lr_exp.start_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
```
```diff
@@ -737,8 +772,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
-            self.optimizer = AdaMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                AdaMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
```
```diff
@@ -750,8 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
         elif self.opt_type == "HybridMuon":
-            self.optimizer = HybridMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                HybridMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
```
```diff
@@ -764,15 +799,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 muon_2d_only=bool(self.opt_param["muon_2d_only"]),
                 min_2d_dim=int(self.opt_param["min_2d_dim"]),
             )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

+        if self.zero_stage > 0 and self.rank == 0:
+            if self.zero_stage == 1:
+                log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.")
+            else:
+                stage = (
+                    "FULL_SHARD (Stage 3)"
+                    if self.zero_stage >= 3
+                    else "SHARD_GRAD_OP (Stage 2)"
+                )
+                log.info(f"Enabled FSDP2 {stage}.")
+
         # Tensorboard
         self.enable_tensorboard = training_params.get("tensorboard", False)
         self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
```
```diff
@@ -822,6 +867,58 @@ def _log_parameter_count(self) -> None:
                 f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
             )

+    def _create_optimizer(
+        self,
+        optimizer_class: type[torch.optim.Optimizer],
+        **kwargs: Any,
+    ) -> torch.optim.Optimizer:
+        """
+        Construct optimizer, wrapping with ZeroRedundancyOptimizer when zero_stage=1.
+
+        Parameters
+        ----------
+        optimizer_class : type[torch.optim.Optimizer]
+            The optimizer class to instantiate.
+        **kwargs : Any
+            Keyword arguments forwarded to the optimizer constructor.
+
+        Returns
+        -------
+        torch.optim.Optimizer
+            Constructed optimizer instance.
+        """
+        if self.zero_stage == 1:
+            return ZeroRedundancyOptimizer(
+                self.wrapper.parameters(),
+                optimizer_class=optimizer_class,
+                **kwargs,
+            )
+        return optimizer_class(self.wrapper.parameters(), **kwargs)
+
+    def _get_inner_module(self) -> ModelWrapper:
+        """Unwrap DDP if needed. FSDP2 is in-place so no unwrapping required."""
+        if self.is_distributed and self.zero_stage <= 1:
+            return self.wrapper.module
+        return self.wrapper
+
+    def _load_optimizer_state(
+        self, optimizer_state_dict: dict[str, Any] | None
+    ) -> None:
+        """Load optimizer state for restart training when available."""
+        if optimizer_state_dict is None or not self.restart_training:
+            return
+        if self.zero_stage >= 2:
+            set_optimizer_state_dict(
+                self.wrapper,
+                self.optimizer,
+                optim_state_dict=optimizer_state_dict,
+                options=StateDictOptions(
+                    full_state_dict=True, broadcast_from_rank0=True
+                ),
+            )
+        else:
+            self.optimizer.load_state_dict(optimizer_state_dict)
+
     def run(self) -> None:
         fout = (
             open(
```
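
`_create_optimizer` relies on `ZeroRedundancyOptimizer` partitioning only the optimizer state across ranks, while DDP still keeps full parameter and gradient replicas. A minimal usage sketch (assumes an initialized process group and a DDP-wrapped `model`, both placeholders):

```python
# ZeRO-1 sketch: every rank holds full params/grads, but Adam's state is
# sharded across ranks, cutting per-rank optimizer memory by ~world_size.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=1e-3,
)
# Training proceeds as usual: loss.backward(); optimizer.step()

# Before checkpointing, gather the sharded state onto rank 0 (as save_model
# does below); state_dict() is then only meaningful on rank 0.
optimizer.consolidate_state_dict(to=0)
```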
```diff
@@ -892,12 +989,30 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    torch.nn.utils.clip_grad_norm_(
+                    # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
+                    total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
                         self.gradient_max_norm,
-                        error_if_nonfinite=True,
                     )
-                with torch.device("cpu"):
+                    if not torch.isfinite(total_norm):
+                        bad_params = []
+                        for name, p in self.wrapper.named_parameters():
+                            if p.grad is not None:
+                                grad_norm = p.grad.data.norm()
+                                if not torch.isfinite(grad_norm):
+                                    bad_params.append(
+                                        f"  {name}: grad_norm={grad_norm}, shape={list(p.shape)}"
+                                    )
+                        detail = (
+                            "\n".join(bad_params)
+                            if bad_params
+                            else "  (all individual grads finite, overflow in norm reduction)"
+                        )
+                        raise RuntimeError(
+                            f"Non-finite gradient norm: {total_norm}\n"
+                            f"Parameters with non-finite gradients:\n{detail}"
+                        )
+                with torch.device(DEVICE):
                     self.optimizer.step()
                     self.scheduler.step()
             elif self.opt_type == "LKF":
```
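
The pattern above works because `clip_grad_norm_` returns the total pre-clip gradient norm, so the finiteness check that `error_if_nonfinite=True` used to perform can be done manually on the returned scalar, which sharded DTensor gradients tolerate. Reduced to its core (a sketch with a placeholder `model`, not the commit's code):

```python
# Core of the manual non-finite check: clip, then inspect the returned norm.
import torch

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if not torch.isfinite(total_norm):
    raise RuntimeError(f"Non-finite gradient norm: {total_norm}")
```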
```diff
@@ -1205,20 +1320,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                 and _step_id != self.start_step
             )
             or (display_step_id) == self.num_steps
-        ) and (self.rank == 0 or dist.get_rank() == 0):
+        ) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
             # Handle the case if rank 0 aborted and re-assigned
             self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
-
-            module = (
-                self.wrapper.module
-                if dist.is_available() and dist.is_initialized()
-                else self.wrapper
-            )
             self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-            log.info(f"Saved model to {self.latest_model}")
-            symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-            with open("checkpoint", "w") as f:
-                f.write(str(self.latest_model))
+            if self.rank == 0 or dist.get_rank() == 0:
+                log.info(f"Saved model to {self.latest_model}")
+                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+                with open("checkpoint", "w") as f:
+                    f.write(str(self.latest_model))

         # tensorboard
         if self.enable_tensorboard and (
```
```diff
@@ -1273,13 +1383,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                 with open("checkpoint", "w") as f:
                     f.write(str(self.latest_model))

+        if self.num_steps == 0 and self.zero_stage > 0:
+            # ZeRO-1 / FSDP: all ranks participate in save_model (collective op)
+            self.latest_model = Path(self.save_ckpt + "-0.pt")
+            self.save_model(self.latest_model, lr=0, step=0)
+
         if (
             self.rank == 0 or dist.get_rank() == 0
         ):  # Handle the case if rank 0 aborted and re-assigned
             if self.num_steps == 0:
-                # when num_steps is 0, the checkpoint is never not saved
-                self.latest_model = Path(self.save_ckpt + "-0.pt")
-                self.save_model(self.latest_model, lr=0, step=0)
+                if self.zero_stage == 0:
+                    # When num_steps is 0, the checkpoint is never saved in the loop
+                    self.latest_model = Path(self.save_ckpt + "-0.pt")
+                    self.save_model(self.latest_model, lr=0, step=0)
             log.info(f"Saved model to {self.latest_model}")
             symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
             with open("checkpoint", "w") as f:
```
```diff
@@ -1321,18 +1437,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
             )

     def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
-        module = (
-            self.wrapper.module
-            if dist.is_available() and dist.is_initialized()
-            else self.wrapper
-        )
+        module = self._get_inner_module()
         module.train_infos["lr"] = float(lr)
         module.train_infos["step"] = step
-        optim_state_dict = deepcopy(self.optimizer.state_dict())
-        for item in optim_state_dict["param_groups"]:
+
+        # === Collect state dicts ===
+        if self.zero_stage >= 2:
+            # FSDP2: collective op, all ranks participate; rank 0 gets full state
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+            model_state = get_model_state_dict(self.wrapper, options=options)
+            optim_state = get_optimizer_state_dict(
+                self.wrapper, self.optimizer, options=options
+            )
+        elif self.zero_stage == 1:
+            # ZeRO-1: consolidate sharded optimizer state to rank 0
+            model_state = module.state_dict()
+            self.optimizer.consolidate_state_dict(to=0)
+            optim_state = (
+                deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
+            )
+        else:
+            model_state = module.state_dict()
+            optim_state = deepcopy(self.optimizer.state_dict())
+
+        # === Only rank 0 writes to disk ===
+        if self.rank != 0:
+            return
+        for item in optim_state["param_groups"]:
             item["lr"] = float(item["lr"])
         torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
+            {"model": model_state, "optimizer": optim_state},
             save_path,
         )
         checkpoint_dir = save_path.parent
```
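
The stage >= 2 branch uses the `torch.distributed.checkpoint.state_dict` helpers, which are collective: every rank must call them, and with `full_state_dict=True` the unsharded tensors are materialized (and here CPU-offloaded) on rank 0, which is why all ranks enter `save_model` but only rank 0 writes. A round-trip sketch (assumes placeholder `model` is `fully_shard`-ed and `opt` is its optimizer):

```python
# FSDP2 full-state checkpoint round trip; all ranks execute every call.
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
model_state = get_model_state_dict(model, options=opts)
optim_state = get_optimizer_state_dict(model, opt, options=opts)
# rank 0 would torch.save({"model": model_state, "optimizer": optim_state}, path)

# On restart, rank 0's full optimizer state is re-scattered to the shards:
set_optimizer_state_dict(
    model,
    opt,
    optim_state_dict=optim_state,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
```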
