
Commit 5bdd765

feat(pt): add FSDP & ZeRO1 (Zero Redundancy Optimizer) support
1 parent f6d5d95 commit 5bdd765

3 files changed

Lines changed: 241 additions & 55 deletions


deepmd/pt/train/training.py

Lines changed: 171 additions & 55 deletions
@@ -81,6 +81,18 @@
 import torch._dynamo
 
 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
+from torch.distributed.fsdp import (
+    fully_shard,
+)
+from torch.distributed.optim import (
+    ZeroRedundancyOptimizer,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import (
     DataLoader,
@@ -131,14 +143,9 @@ def __init__(
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
         )
-        self.rank = (
-            dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        )
-        self.world_size = (
-            dist.get_world_size()
-            if dist.is_available() and dist.is_initialized()
-            else 1
-        )
+        self.is_distributed = dist.is_available() and dist.is_initialized()
+        self.rank = dist.get_rank() if self.is_distributed else 0
+        self.world_size = dist.get_world_size() if self.is_distributed else 1
         self.num_model = len(self.model_keys)
 
         # Iteration config
@@ -154,6 +161,15 @@ def __init__(
         self.change_bias_after_training = training_params.get(
             "change_bias_after_training", False
         )
+        self.zero_stage = int(training_params.get("zero_stage", 0))
+        if self.zero_stage > 0 and not self.is_distributed:
+            raise ValueError(
+                "training.zero_stage requires distributed launch via torchrun."
+            )
+        if self.zero_stage > 0 and self.change_bias_after_training:
+            raise ValueError(
+                "training.zero_stage does not support change_bias_after_training."
+            )
         self.lcurve_should_print_header = True
 
         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
@@ -300,6 +316,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             )
         else:
             self.opt_type, self.opt_param = get_opt_param(training_params)
+        if self.zero_stage > 0 and self.multi_task:
+            raise ValueError(
+                "training.zero_stage is currently only supported in single-task training."
+            )
+        if self.zero_stage > 0 and self.opt_type == "LKF":
+            raise ValueError("training.zero_stage does not support LKF optimizer.")
 
         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -690,15 +712,25 @@ def single_model_finetune(
                 data_stat_protect=_data_stat_protect[0],
             )
 
-        if dist.is_available() and dist.is_initialized():
+        if self.is_distributed:
             torch.cuda.set_device(LOCAL_RANK)
-            # DDP will guarantee the model parameters are identical across all processes
-            self.wrapper = DDP(
-                self.wrapper,
-                device_ids=[LOCAL_RANK],
-                find_unused_parameters=True,
-                output_device=LOCAL_RANK,
-            )
+            if self.zero_stage >= 2:
+                # FSDP2 does NOT broadcast params (unlike DDP constructor).
+                # Ensure all ranks share identical weights before sharding.
+                for p in self.wrapper.parameters():
+                    dist.broadcast(p.data, src=0)
+                for b in self.wrapper.buffers():
+                    dist.broadcast(b.data, src=0)
+                reshard = self.zero_stage >= 3
+                fully_shard(self.wrapper, reshard_after_forward=reshard)
+            else:
+                # zero_stage=0 or 1: standard DDP (ZeRO-1 will wrap the optimizer)
+                self.wrapper = DDP(
+                    self.wrapper,
+                    device_ids=[LOCAL_RANK],
+                    find_unused_parameters=self.multi_task,
+                    output_device=LOCAL_RANK,
+                )
 
         # TODO add lr warmups for multitask
         # author: iProzd
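
Note: the hunk above is the core of the FSDP2 path. As a standalone illustration (not part of this commit; the wrap_for_zero helper and its argument names are hypothetical, and it assumes torch >= 2.6 where fully_shard is importable from torch.distributed.fsdp and a torchrun launch that sets LOCAL_RANK):

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_zero(model: torch.nn.Module, zero_stage: int) -> torch.nn.Module:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.cuda()
    if zero_stage >= 2:
        # fully_shard does not synchronize initial weights (DDP's constructor does),
        # so broadcast parameters and buffers from rank 0 before sharding.
        for t in list(model.parameters()) + list(model.buffers()):
            dist.broadcast(t.data, src=0)
        # Stage 3 reshards parameters after forward (ZeRO-3); stage 2 keeps the
        # gathered parameters resident and only shards gradients and optimizer state.
        fully_shard(model, reshard_after_forward=(zero_stage >= 3))
        return model
    # zero_stage 0/1: plain DDP; ZeRO-1 sharding happens in the optimizer wrapper.
    return DDP(model, device_ids=[local_rank], output_device=local_rank)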
@@ -714,20 +746,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
             if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.Adam,
                     lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
             else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.AdamW,
                     lr=self.lr_exp.start_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
@@ -737,8 +768,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
-            self.optimizer = AdaMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                AdaMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -750,8 +781,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
         elif self.opt_type == "HybridMuon":
-            self.optimizer = HybridMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                HybridMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -764,15 +795,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 muon_2d_only=bool(self.opt_param["muon_2d_only"]),
                 min_2d_dim=int(self.opt_param["min_2d_dim"]),
             )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
+        if self.zero_stage > 0 and self.rank == 0:
+            if self.zero_stage == 1:
+                log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.")
+            else:
+                stage = (
+                    "FULL_SHARD (Stage 3)"
+                    if self.zero_stage >= 3
+                    else "SHARD_GRAD_OP (Stage 2)"
+                )
+                log.info(f"Enabled FSDP2 {stage}.")
+
         # Tensorboard
         self.enable_tensorboard = training_params.get("tensorboard", False)
         self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
@@ -822,6 +863,58 @@ def _log_parameter_count(self) -> None:
                 f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
             )
 
+    def _create_optimizer(
+        self,
+        optimizer_class: type[torch.optim.Optimizer],
+        **kwargs: Any,
+    ) -> torch.optim.Optimizer:
+        """
+        Construct optimizer, wrapping with ZeroRedundancyOptimizer when zero_stage=1.
+
+        Parameters
+        ----------
+        optimizer_class : type[torch.optim.Optimizer]
+            The optimizer class to instantiate.
+        **kwargs : Any
+            Keyword arguments forwarded to the optimizer constructor.
+
+        Returns
+        -------
+        torch.optim.Optimizer
+            Constructed optimizer instance.
+        """
+        if self.zero_stage == 1:
+            return ZeroRedundancyOptimizer(
+                self.wrapper.parameters(),
+                optimizer_class=optimizer_class,
+                **kwargs,
+            )
+        return optimizer_class(self.wrapper.parameters(), **kwargs)
+
+    def _get_inner_module(self) -> ModelWrapper:
+        """Unwrap DDP if needed. FSDP2 is in-place so no unwrapping required."""
+        if self.is_distributed and self.zero_stage <= 1:
+            return self.wrapper.module
+        return self.wrapper
+
+    def _load_optimizer_state(
+        self, optimizer_state_dict: dict[str, Any] | None
+    ) -> None:
+        """Load optimizer state for restart training when available."""
+        if optimizer_state_dict is None or not self.restart_training:
+            return
+        if self.zero_stage >= 2:
+            set_optimizer_state_dict(
+                self.wrapper,
+                self.optimizer,
+                optim_state_dict=optimizer_state_dict,
+                options=StateDictOptions(
+                    full_state_dict=True, broadcast_from_rank0=True
+                ),
+            )
+        else:
+            self.optimizer.load_state_dict(optimizer_state_dict)
+
     def run(self) -> None:
         fout = (
             open(
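
Note: _create_optimizer above routes every supported optimizer through ZeroRedundancyOptimizer when zero_stage=1. A minimal sketch of what that wrapping does, with a hypothetical build_zero1_optimizer helper and a placeholder model (neither name is from this commit):

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer


def build_zero1_optimizer(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
    # Each rank keeps only ~1/world_size of the Adam exp_avg / exp_avg_sq state;
    # after each step() the updated parameters are broadcast back to every rank.
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=lr,
    )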
@@ -892,12 +985,16 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    torch.nn.utils.clip_grad_norm_(
+                    # Avoid error_if_nonfinite=True: FSDP2 sharded
+                    # DTensor gradients may not support it. Manual
+                    # isfinite check achieves the same fail-fast behavior.
+                    total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
                         self.gradient_max_norm,
-                        error_if_nonfinite=True,
                     )
-                with torch.device("cpu"):
+                    if not torch.isfinite(total_norm):
+                        raise RuntimeError(f"Non-finite gradient norm: {total_norm}")
+                with torch.device(DEVICE):
                     self.optimizer.step()
                     self.scheduler.step()
             elif self.opt_type == "LKF":
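
Note: the clipping change above swaps error_if_nonfinite=True for an explicit check on the returned norm. A hedged standalone sketch of that pattern (the clip_and_check helper and model placeholder are illustrative only):

import torch


def clip_and_check(model: torch.nn.Module, max_norm: float) -> torch.Tensor:
    # clip_grad_norm_ returns the total norm computed before clipping; checking it
    # manually replaces error_if_nonfinite=True, which may not work on the sharded
    # DTensor gradients produced by FSDP2.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        raise RuntimeError(f"Non-finite gradient norm: {total_norm}")
    return total_norm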
@@ -1205,20 +1302,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     and _step_id != self.start_step
                 )
                 or (display_step_id) == self.num_steps
-            ) and (self.rank == 0 or dist.get_rank() == 0):
+            ) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
                 # Handle the case if rank 0 aborted and re-assigned
                 self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
-
-                module = (
-                    self.wrapper.module
-                    if dist.is_available() and dist.is_initialized()
-                    else self.wrapper
-                )
                 self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                log.info(f"Saved model to {self.latest_model}")
-                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-                with open("checkpoint", "w") as f:
-                    f.write(str(self.latest_model))
+                if self.rank == 0 or dist.get_rank() == 0:
+                    log.info(f"Saved model to {self.latest_model}")
+                    symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+                    with open("checkpoint", "w") as f:
+                        f.write(str(self.latest_model))
 
             # tensorboard
             if self.enable_tensorboard and (
@@ -1273,13 +1365,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                 with open("checkpoint", "w") as f:
                     f.write(str(self.latest_model))
 
+        if self.num_steps == 0 and self.zero_stage > 0:
+            # ZeRO-1 / FSDP: all ranks participate in save_model (collective op)
+            self.latest_model = Path(self.save_ckpt + "-0.pt")
+            self.save_model(self.latest_model, lr=0, step=0)
+
         if (
             self.rank == 0 or dist.get_rank() == 0
         ): # Handle the case if rank 0 aborted and re-assigned
             if self.num_steps == 0:
-                # when num_steps is 0, the checkpoint is never not saved
-                self.latest_model = Path(self.save_ckpt + "-0.pt")
-                self.save_model(self.latest_model, lr=0, step=0)
+                if self.zero_stage == 0:
+                    # When num_steps is 0, the checkpoint is never saved in the loop
+                    self.latest_model = Path(self.save_ckpt + "-0.pt")
+                    self.save_model(self.latest_model, lr=0, step=0)
             log.info(f"Saved model to {self.latest_model}")
             symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
             with open("checkpoint", "w") as f:
@@ -1321,18 +1419,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
             )
 
     def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
-        module = (
-            self.wrapper.module
-            if dist.is_available() and dist.is_initialized()
-            else self.wrapper
-        )
+        module = self._get_inner_module()
         module.train_infos["lr"] = float(lr)
         module.train_infos["step"] = step
-        optim_state_dict = deepcopy(self.optimizer.state_dict())
-        for item in optim_state_dict["param_groups"]:
+
+        # === Collect state dicts ===
+        if self.zero_stage >= 2:
+            # FSDP2: collective op, all ranks participate; rank 0 gets full state
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+            model_state = get_model_state_dict(self.wrapper, options=options)
+            optim_state = get_optimizer_state_dict(
+                self.wrapper, self.optimizer, options=options
+            )
+        elif self.zero_stage == 1:
+            # ZeRO-1: consolidate sharded optimizer state to rank 0
+            model_state = module.state_dict()
+            self.optimizer.consolidate_state_dict(to=0)
+            optim_state = (
+                deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
+            )
+        else:
+            model_state = module.state_dict()
+            optim_state = deepcopy(self.optimizer.state_dict())
+
+        # === Only rank 0 writes to disk ===
+        if self.rank != 0:
+            return
+        for item in optim_state["param_groups"]:
             item["lr"] = float(item["lr"])
         torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
+            {"model": model_state, "optimizer": optim_state},
             save_path,
         )
         checkpoint_dir = save_path.parent
deepmd/utils/argcheck.py

Lines changed: 24 additions & 0 deletions
@@ -3273,6 +3273,23 @@ def training_args(
     doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode."
     doc_data_dict = "The multiple definition of the data, used in the multi-task mode."
     doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)."
+    doc_zero_stage = (
+        "ZeRO optimization stage for distributed training memory reduction. "
+        "0: standard DDP, lowest communication overhead but highest memory usage "
+        "(full optimizer states, gradients, and parameters replicated on every GPU). "
+        "1: DDP + ZeRO stage-1, shards optimizer states across GPUs via "
+        "ZeroRedundancyOptimizer; same communication volume as DDP (2x model size) "
+        "but reduces optimizer memory to 1/N per GPU. "
+        "2: FSDP2 stage-2, shards optimizer states and gradients; same communication "
+        "volume as stage-1 but further reduces gradient memory to 1/N per GPU. "
+        "Note: FSDP2 introduces DTensor dispatch overhead that can slow down "
+        "models with many small layers; use torch.compile to mitigate. "
+        "3: FSDP2 stage-3, shards parameters as well; maximum memory savings but "
+        "50% more communication (3x model size) due to parameter all-gather in "
+        "both forward and backward passes. "
+        "Default is 0. Requires distributed launch via torchrun. "
+        "Currently supports single-task training; does not support LKF or change_bias_after_training."
+    )
 
     arg_training_data = training_data_args()
     arg_validation_data = validation_data_args()
@@ -3395,6 +3412,13 @@
             default=1,
             doc=doc_only_pd_supported + doc_acc_freq,
         ),
+        Argument(
+            "zero_stage",
+            int,
+            optional=True,
+            default=0,
+            doc=doc_only_pt_supported + doc_zero_stage,
+        ),
     ]
     variants = [
         Variant(
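
Note: a hedged usage sketch of the new option. zero_stage lives in the training section of the input script and only takes effect under a torchrun launch; the surrounding keys and the launch command shown here are illustrative, not part of this diff:

# Hypothetical input-script fragment expressed as a Python dict: only the
# "zero_stage" key comes from this commit, the rest stands in for the usual keys.
training_section = {
    "zero_stage": 1,  # 0 = plain DDP, 1 = ZeRO-1 optimizer sharding, 2/3 = FSDP2
    "numb_steps": 1000000,
    # ... training_data, validation_data, and the other usual training keys
}
# Launch with torchrun so torch.distributed is initialized, e.g. (illustrative):
#   torchrun --nproc_per_node=4 --no-python dp --pt train input.json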
