deepmd/pt/train/training.py (244 changes: 189 additions & 55 deletions)
@@ -81,6 +81,18 @@
import torch._dynamo

import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import (
fully_shard,
)
from torch.distributed.optim import (
ZeroRedundancyOptimizer,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import (
DataLoader,
@@ -131,14 +143,9 @@ def __init__(
self.model_keys = (
list(model_params["model_dict"]) if self.multi_task else ["Default"]
)
self.rank = (
dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
)
self.world_size = (
dist.get_world_size()
if dist.is_available() and dist.is_initialized()
else 1
)
self.is_distributed = dist.is_available() and dist.is_initialized()
self.rank = dist.get_rank() if self.is_distributed else 0
self.world_size = dist.get_world_size() if self.is_distributed else 1
self.num_model = len(self.model_keys)

# Iteration config
@@ -154,6 +161,19 @@
self.change_bias_after_training = training_params.get(
"change_bias_after_training", False
)
self.zero_stage = int(training_params.get("zero_stage", 0))
if self.zero_stage not in (0, 1, 2, 3):
raise ValueError(
f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
)
if self.zero_stage > 0 and not self.is_distributed:
raise ValueError(
"training.zero_stage requires distributed launch via torchrun."
)
if self.zero_stage > 0 and self.change_bias_after_training:
raise ValueError(
"training.zero_stage does not support change_bias_after_training."
)
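For context, a minimal sketch (not part of this diff) of how the new option would appear in the parsed training parameters; every key other than "zero_stage" is a placeholder:
# Illustrative only: hypothetical training-params fragment enabling ZeRO.
example_training_params = {
    "zero_stage": 2,  # 0 = off, 1 = DDP + ZeRO-1 optimizer-state sharding,
                      # 2 = FSDP2 SHARD_GRAD_OP, 3 = FSDP2 FULL_SHARD
    # ... remaining training keys unchanged ...
}
# Stages 1-3 require a distributed launch via torchrun (enforced above).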
self.lcurve_should_print_header = True

def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
Expand Down Expand Up @@ -300,6 +320,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
)
else:
self.opt_type, self.opt_param = get_opt_param(training_params)
if self.zero_stage > 0 and self.multi_task:
raise ValueError(
"training.zero_stage is currently only supported in single-task training."
)
if self.zero_stage > 0 and self.opt_type == "LKF":
raise ValueError("training.zero_stage does not support LKF optimizer.")

# loss_param_tmp for Hessian activation
loss_param_tmp = None
@@ -690,15 +716,25 @@ def single_model_finetune(
data_stat_protect=_data_stat_protect[0],
)

if dist.is_available() and dist.is_initialized():
if self.is_distributed:
torch.cuda.set_device(LOCAL_RANK)
# DDP will guarantee the model parameters are identical across all processes
self.wrapper = DDP(
self.wrapper,
device_ids=[LOCAL_RANK],
find_unused_parameters=True,
output_device=LOCAL_RANK,
)
if self.zero_stage >= 2:
# FSDP2 does NOT broadcast params (unlike DDP constructor).
# Ensure all ranks share identical weights before sharding.
for p in self.wrapper.parameters():
dist.broadcast(p.data, src=0)
for b in self.wrapper.buffers():
dist.broadcast(b.data, src=0)
reshard = self.zero_stage >= 3
self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard)
else:
# zero_stage=0 or 1: standard DDP (ZeRO-1 will wrap the optimizer)
self.wrapper = DDP(
self.wrapper,
device_ids=[LOCAL_RANK],
find_unused_parameters=True,
output_device=LOCAL_RANK,
)
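For reference, a self-contained sketch of the broadcast-then-shard pattern used above; the helper name and toy setup are illustrative assumptions, not part of this PR:
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

def shard_for_zero(model: torch.nn.Module, zero_stage: int) -> torch.nn.Module:
    # Assumes dist.init_process_group() has already run (e.g. a torchrun launch).
    # fully_shard (FSDP2) does not broadcast parameters the way the DDP
    # constructor does, so replicate rank 0's weights and buffers first.
    for p in model.parameters():
        dist.broadcast(p.data, src=0)
    for b in model.buffers():
        dist.broadcast(b.data, src=0)
    # reshard_after_forward=True (stage 3) frees the gathered parameters right
    # after forward; False (stage 2) keeps them until after backward, sharding
    # only gradients and optimizer state.
    return fully_shard(model, reshard_after_forward=(zero_stage >= 3))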

# TODO add lr warmups for multitask
# author: iProzd
@@ -714,20 +750,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
# author: iProzd
if self.opt_type in ["Adam", "AdamW"]:
if self.opt_type == "Adam":
self.optimizer = torch.optim.Adam(
self.wrapper.parameters(),
self.optimizer = self._create_optimizer(
torch.optim.Adam,
lr=self.lr_exp.start_lr,
fused=False if DEVICE.type == "cpu" else True,
fused=DEVICE.type != "cpu",
)
else:
self.optimizer = torch.optim.AdamW(
self.wrapper.parameters(),
self.optimizer = self._create_optimizer(
torch.optim.AdamW,
lr=self.lr_exp.start_lr,
weight_decay=float(self.opt_param["weight_decay"]),
fused=False if DEVICE.type == "cpu" else True,
fused=DEVICE.type != "cpu",
)
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.load_state_dict(optimizer_state_dict)
self._load_optimizer_state(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
@@ -737,8 +772,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
)
elif self.opt_type == "AdaMuon":
self.optimizer = AdaMuonOptimizer(
self.wrapper.parameters(),
self.optimizer = self._create_optimizer(
AdaMuonOptimizer,
lr=self.lr_exp.start_lr,
momentum=float(self.opt_param["momentum"]),
weight_decay=float(self.opt_param["weight_decay"]),
@@ -750,8 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
)
elif self.opt_type == "HybridMuon":
self.optimizer = HybridMuonOptimizer(
self.wrapper.parameters(),
self.optimizer = self._create_optimizer(
HybridMuonOptimizer,
lr=self.lr_exp.start_lr,
momentum=float(self.opt_param["momentum"]),
weight_decay=float(self.opt_param["weight_decay"]),
@@ -764,15 +799,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
muon_2d_only=bool(self.opt_param["muon_2d_only"]),
min_2d_dim=int(self.opt_param["min_2d_dim"]),
)
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.load_state_dict(optimizer_state_dict)
self._load_optimizer_state(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
)
else:
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

if self.zero_stage > 0 and self.rank == 0:
if self.zero_stage == 1:
log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.")
else:
stage = (
"FULL_SHARD (Stage 3)"
if self.zero_stage >= 3
else "SHARD_GRAD_OP (Stage 2)"
)
log.info(f"Enabled FSDP2 {stage}.")

# Tensorboard
self.enable_tensorboard = training_params.get("tensorboard", False)
self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
Expand Down Expand Up @@ -822,6 +867,58 @@ def _log_parameter_count(self) -> None:
f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
)

def _create_optimizer(
self,
optimizer_class: type[torch.optim.Optimizer],
**kwargs: Any,
) -> torch.optim.Optimizer:
"""
Construct optimizer, wrapping with ZeroRedundancyOptimizer when zero_stage=1.

Parameters
----------
optimizer_class : type[torch.optim.Optimizer]
The optimizer class to instantiate.
**kwargs : Any
Keyword arguments forwarded to the optimizer constructor.

Returns
-------
torch.optim.Optimizer
Constructed optimizer instance.
"""
if self.zero_stage == 1:
return ZeroRedundancyOptimizer(
self.wrapper.parameters(),
optimizer_class=optimizer_class,
**kwargs,
)
return optimizer_class(self.wrapper.parameters(), **kwargs)
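A hedged usage sketch of the zero_stage == 1 branch in isolation; the toy model, hyperparameters, and prior process-group setup are assumptions:
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes torch.distributed is already initialized (torchrun launch).
model = torch.nn.Linear(16, 16)  # stand-in for the DDP-wrapped self.wrapper
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-3,
    weight_decay=1e-2,
)
# Each rank holds only its shard of the optimizer state. Before rank 0 can
# write a full checkpoint, the shards are gathered with:
# optimizer.consolidate_state_dict(to=0)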

def _get_inner_module(self) -> ModelWrapper:
"""Unwrap DDP if needed. FSDP2 is in-place so no unwrapping required."""
if self.is_distributed and self.zero_stage <= 1:
return self.wrapper.module
return self.wrapper

def _load_optimizer_state(
self, optimizer_state_dict: dict[str, Any] | None
) -> None:
"""Load optimizer state for restart training when available."""
if optimizer_state_dict is None or not self.restart_training:
return
if self.zero_stage >= 2:
set_optimizer_state_dict(
self.wrapper,
self.optimizer,
optim_state_dict=optimizer_state_dict,
options=StateDictOptions(
full_state_dict=True, broadcast_from_rank0=True
),
)
else:
self.optimizer.load_state_dict(optimizer_state_dict)

def run(self) -> None:
fout = (
open(
Expand Down Expand Up @@ -892,12 +989,30 @@ def step(_step_id: int, task_key: str = "Default") -> None:
)
loss.backward()
if self.gradient_max_norm > 0.0:
torch.nn.utils.clip_grad_norm_(
# FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
total_norm = torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(),
self.gradient_max_norm,
error_if_nonfinite=True,
)
with torch.device("cpu"):
if not torch.isfinite(total_norm):
bad_params = []
for name, p in self.wrapper.named_parameters():
if p.grad is not None:
grad_norm = p.grad.data.norm()
if not torch.isfinite(grad_norm):
bad_params.append(
f" {name}: grad_norm={grad_norm}, shape={list(p.shape)}"
)
detail = (
"\n".join(bad_params)
if bad_params
else " (all individual grads finite, overflow in norm reduction)"
)
raise RuntimeError(
f"Non-finite gradient norm: {total_norm}\n"
f"Parameters with non-finite gradients:\n{detail}"
)
with torch.device(DEVICE):
self.optimizer.step()
self.scheduler.step()
elif self.opt_type == "LKF":
Expand Down Expand Up @@ -1205,20 +1320,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
and _step_id != self.start_step
)
or (display_step_id) == self.num_steps
) and (self.rank == 0 or dist.get_rank() == 0):
) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
# Handle the case if rank 0 aborted and re-assigned
self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")

module = (
self.wrapper.module
if dist.is_available() and dist.is_initialized()
else self.wrapper
)
self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))
if self.rank == 0 or dist.get_rank() == 0:
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

# tensorboard
if self.enable_tensorboard and (
Expand Down Expand Up @@ -1273,13 +1383,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

if self.num_steps == 0 and self.zero_stage > 0:
# ZeRO-1 / FSDP: all ranks participate in save_model (collective op)
self.latest_model = Path(self.save_ckpt + "-0.pt")
self.save_model(self.latest_model, lr=0, step=0)

if (
self.rank == 0 or dist.get_rank() == 0
): # Handle the case if rank 0 aborted and re-assigned
if self.num_steps == 0:
# when num_steps is 0, the checkpoint is never not saved
self.latest_model = Path(self.save_ckpt + "-0.pt")
self.save_model(self.latest_model, lr=0, step=0)
if self.zero_stage == 0:
# When num_steps is 0, the checkpoint is never saved in the loop
self.latest_model = Path(self.save_ckpt + "-0.pt")
self.save_model(self.latest_model, lr=0, step=0)
log.info(f"Saved model to {self.latest_model}")
symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
Expand Down Expand Up @@ -1321,18 +1437,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
)

def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
module = (
self.wrapper.module
if dist.is_available() and dist.is_initialized()
else self.wrapper
)
module = self._get_inner_module()
module.train_infos["lr"] = float(lr)
module.train_infos["step"] = step
optim_state_dict = deepcopy(self.optimizer.state_dict())
for item in optim_state_dict["param_groups"]:

# === Collect state dicts ===
if self.zero_stage >= 2:
# FSDP2: collective op, all ranks participate; rank 0 gets full state
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
model_state = get_model_state_dict(self.wrapper, options=options)
optim_state = get_optimizer_state_dict(
self.wrapper, self.optimizer, options=options
)
elif self.zero_stage == 1:
# ZeRO-1: consolidate sharded optimizer state to rank 0
model_state = module.state_dict()
self.optimizer.consolidate_state_dict(to=0)
optim_state = (
deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
)
else:
model_state = module.state_dict()
optim_state = deepcopy(self.optimizer.state_dict())

# === Only rank 0 writes to disk ===
if self.rank != 0:
return
for item in optim_state["param_groups"]:
item["lr"] = float(item["lr"])
torch.save(
{"model": module.state_dict(), "optimizer": optim_state_dict},
{"model": model_state, "optimizer": optim_state},
save_path,
)
checkpoint_dir = save_path.parent