diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 5b35206661..be480fda3f 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -81,6 +81,18 @@ import torch._dynamo import torch.distributed as dist +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_optimizer_state_dict, +) +from torch.distributed.fsdp import ( + fully_shard, +) +from torch.distributed.optim import ( + ZeroRedundancyOptimizer, +) from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import ( DataLoader, @@ -131,14 +143,9 @@ def __init__( self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] ) - self.rank = ( - dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 - ) - self.world_size = ( - dist.get_world_size() - if dist.is_available() and dist.is_initialized() - else 1 - ) + self.is_distributed = dist.is_available() and dist.is_initialized() + self.rank = dist.get_rank() if self.is_distributed else 0 + self.world_size = dist.get_world_size() if self.is_distributed else 1 self.num_model = len(self.model_keys) # Iteration config @@ -154,6 +161,19 @@ def __init__( self.change_bias_after_training = training_params.get( "change_bias_after_training", False ) + self.zero_stage = int(training_params.get("zero_stage", 0)) + if self.zero_stage not in (0, 1, 2, 3): + raise ValueError( + f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}" + ) + if self.zero_stage > 0 and not self.is_distributed: + raise ValueError( + "training.zero_stage requires distributed launch via torchrun." + ) + if self.zero_stage > 0 and self.change_bias_after_training: + raise ValueError( + "training.zero_stage does not support change_bias_after_training." + ) self.lcurve_should_print_header = True def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: @@ -300,6 +320,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) else: self.opt_type, self.opt_param = get_opt_param(training_params) + if self.zero_stage > 0 and self.multi_task: + raise ValueError( + "training.zero_stage is currently only supported in single-task training." + ) + if self.zero_stage > 0 and self.opt_type == "LKF": + raise ValueError("training.zero_stage does not support LKF optimizer.") # loss_param_tmp for Hessian activation loss_param_tmp = None @@ -690,15 +716,25 @@ def single_model_finetune( data_stat_protect=_data_stat_protect[0], ) - if dist.is_available() and dist.is_initialized(): + if self.is_distributed: torch.cuda.set_device(LOCAL_RANK) - # DDP will guarantee the model parameters are identical across all processes - self.wrapper = DDP( - self.wrapper, - device_ids=[LOCAL_RANK], - find_unused_parameters=True, - output_device=LOCAL_RANK, - ) + if self.zero_stage >= 2: + # FSDP2 does NOT broadcast params (unlike DDP constructor). + # Ensure all ranks share identical weights before sharding. 
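+ # Buffers are left unsharded (replicated) by fully_shard, so they are broadcast below as well.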
+ for p in self.wrapper.parameters(): + dist.broadcast(p.data, src=0) + for b in self.wrapper.buffers(): + dist.broadcast(b.data, src=0) + reshard = self.zero_stage >= 3 + self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard) + else: + # zero_stage=0 or 1: standard DDP (ZeRO-1 will wrap the optimizer) + self.wrapper = DDP( + self.wrapper, + device_ids=[LOCAL_RANK], + find_unused_parameters=True, + output_device=LOCAL_RANK, + ) # TODO add lr warmups for multitask # author: iProzd @@ -714,20 +750,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: # author: iProzd if self.opt_type in ["Adam", "AdamW"]: if self.opt_type == "Adam": - self.optimizer = torch.optim.Adam( - self.wrapper.parameters(), + self.optimizer = self._create_optimizer( + torch.optim.Adam, lr=self.lr_exp.start_lr, - fused=False if DEVICE.type == "cpu" else True, + fused=DEVICE.type != "cpu", ) else: - self.optimizer = torch.optim.AdamW( - self.wrapper.parameters(), + self.optimizer = self._create_optimizer( + torch.optim.AdamW, lr=self.lr_exp.start_lr, weight_decay=float(self.opt_param["weight_decay"]), - fused=False if DEVICE.type == "cpu" else True, + fused=DEVICE.type != "cpu", ) - if optimizer_state_dict is not None and self.restart_training: - self.optimizer.load_state_dict(optimizer_state_dict) + self._load_optimizer_state(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), @@ -737,8 +772,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"] ) elif self.opt_type == "AdaMuon": - self.optimizer = AdaMuonOptimizer( - self.wrapper.parameters(), + self.optimizer = self._create_optimizer( + AdaMuonOptimizer, lr=self.lr_exp.start_lr, momentum=float(self.opt_param["momentum"]), weight_decay=float(self.opt_param["weight_decay"]), @@ -750,8 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), ) elif self.opt_type == "HybridMuon": - self.optimizer = HybridMuonOptimizer( - self.wrapper.parameters(), + self.optimizer = self._create_optimizer( + HybridMuonOptimizer, lr=self.lr_exp.start_lr, momentum=float(self.opt_param["momentum"]), weight_decay=float(self.opt_param["weight_decay"]), @@ -764,8 +799,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: muon_2d_only=bool(self.opt_param["muon_2d_only"]), min_2d_dim=int(self.opt_param["min_2d_dim"]), ) - if optimizer_state_dict is not None and self.restart_training: - self.optimizer.load_state_dict(optimizer_state_dict) + self._load_optimizer_state(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), @@ -773,6 +807,17 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") + if self.zero_stage > 0 and self.rank == 0: + if self.zero_stage == 1: + log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.") + else: + stage = ( + "FULL_SHARD (Stage 3)" + if self.zero_stage >= 3 + else "SHARD_GRAD_OP (Stage 2)" + ) + log.info(f"Enabled FSDP2 {stage}.") + # Tensorboard self.enable_tensorboard = training_params.get("tensorboard", False) self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") @@ -822,6 +867,58 @@ def _log_parameter_count(self) -> None: f"Model Params 
[{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)" ) + def _create_optimizer( + self, + optimizer_class: type[torch.optim.Optimizer], + **kwargs: Any, + ) -> torch.optim.Optimizer: + """ + Construct optimizer, wrapping with ZeroRedundancyOptimizer when zero_stage=1. + + Parameters + ---------- + optimizer_class : type[torch.optim.Optimizer] + The optimizer class to instantiate. + **kwargs : Any + Keyword arguments forwarded to the optimizer constructor. + + Returns + ------- + torch.optim.Optimizer + Constructed optimizer instance. + """ + if self.zero_stage == 1: + return ZeroRedundancyOptimizer( + self.wrapper.parameters(), + optimizer_class=optimizer_class, + **kwargs, + ) + return optimizer_class(self.wrapper.parameters(), **kwargs) + + def _get_inner_module(self) -> ModelWrapper: + """Unwrap DDP if needed. FSDP2 is in-place so no unwrapping required.""" + if self.is_distributed and self.zero_stage <= 1: + return self.wrapper.module + return self.wrapper + + def _load_optimizer_state( + self, optimizer_state_dict: dict[str, Any] | None + ) -> None: + """Load optimizer state for restart training when available.""" + if optimizer_state_dict is None or not self.restart_training: + return + if self.zero_stage >= 2: + set_optimizer_state_dict( + self.wrapper, + self.optimizer, + optim_state_dict=optimizer_state_dict, + options=StateDictOptions( + full_state_dict=True, broadcast_from_rank0=True + ), + ) + else: + self.optimizer.load_state_dict(optimizer_state_dict) + def run(self) -> None: fout = ( open( @@ -892,12 +989,30 @@ def step(_step_id: int, task_key: str = "Default") -> None: ) loss.backward() if self.gradient_max_norm > 0.0: - torch.nn.utils.clip_grad_norm_( + # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead. 
+ total_norm = torch.nn.utils.clip_grad_norm_( self.wrapper.parameters(), self.gradient_max_norm, - error_if_nonfinite=True, ) - with torch.device("cpu"): + if not torch.isfinite(total_norm): + bad_params = [] + for name, p in self.wrapper.named_parameters(): + if p.grad is not None: + grad_norm = p.grad.data.norm() + if not torch.isfinite(grad_norm): + bad_params.append( + f" {name}: grad_norm={grad_norm}, shape={list(p.shape)}" + ) + detail = ( + "\n".join(bad_params) + if bad_params + else " (all individual grads finite, overflow in norm reduction)" + ) + raise RuntimeError( + f"Non-finite gradient norm: {total_norm}\n" + f"Parameters with non-finite gradients:\n{detail}" + ) + with torch.device(DEVICE): self.optimizer.step() self.scheduler.step() elif self.opt_type == "LKF": @@ -1205,20 +1320,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict: and _step_id != self.start_step ) or (display_step_id) == self.num_steps - ) and (self.rank == 0 or dist.get_rank() == 0): + ) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0): # Handle the case if rank 0 aborted and re-assigned self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt") - - module = ( - self.wrapper.module - if dist.is_available() and dist.is_initialized() - else self.wrapper - ) self.save_model(self.latest_model, lr=cur_lr, step=_step_id) - log.info(f"Saved model to {self.latest_model}") - symlink_prefix_files(self.latest_model.stem, self.save_ckpt) - with open("checkpoint", "w") as f: - f.write(str(self.latest_model)) + if self.rank == 0 or dist.get_rank() == 0: + log.info(f"Saved model to {self.latest_model}") + symlink_prefix_files(self.latest_model.stem, self.save_ckpt) + with open("checkpoint", "w") as f: + f.write(str(self.latest_model)) # tensorboard if self.enable_tensorboard and ( @@ -1273,13 +1383,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict: with open("checkpoint", "w") as f: f.write(str(self.latest_model)) + if self.num_steps == 0 and self.zero_stage > 0: + # ZeRO-1 / FSDP: all ranks participate in save_model (collective op) + self.latest_model = Path(self.save_ckpt + "-0.pt") + self.save_model(self.latest_model, lr=0, step=0) + if ( self.rank == 0 or dist.get_rank() == 0 ): # Handle the case if rank 0 aborted and re-assigned if self.num_steps == 0: - # when num_steps is 0, the checkpoint is never not saved - self.latest_model = Path(self.save_ckpt + "-0.pt") - self.save_model(self.latest_model, lr=0, step=0) + if self.zero_stage == 0: + # When num_steps is 0, the checkpoint is never saved in the loop + self.latest_model = Path(self.save_ckpt + "-0.pt") + self.save_model(self.latest_model, lr=0, step=0) log.info(f"Saved model to {self.latest_model}") symlink_prefix_files(self.latest_model.stem, self.save_ckpt) with open("checkpoint", "w") as f: @@ -1321,18 +1437,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict: ) def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None: - module = ( - self.wrapper.module - if dist.is_available() and dist.is_initialized() - else self.wrapper - ) + module = self._get_inner_module() module.train_infos["lr"] = float(lr) module.train_infos["step"] = step - optim_state_dict = deepcopy(self.optimizer.state_dict()) - for item in optim_state_dict["param_groups"]: + + # === Collect state dicts === + if self.zero_stage >= 2: + # FSDP2: collective op, all ranks participate; rank 0 gets full state + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + model_state = 
get_model_state_dict(self.wrapper, options=options) + optim_state = get_optimizer_state_dict( + self.wrapper, self.optimizer, options=options + ) + elif self.zero_stage == 1: + # ZeRO-1: consolidate sharded optimizer state to rank 0 + model_state = module.state_dict() + self.optimizer.consolidate_state_dict(to=0) + optim_state = ( + deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {} + ) + else: + model_state = module.state_dict() + optim_state = deepcopy(self.optimizer.state_dict()) + + # === Only rank 0 writes to disk === + if self.rank != 0: + return + for item in optim_state["param_groups"]: item["lr"] = float(item["lr"]) torch.save( - {"model": module.state_dict(), "optimizer": optim_state_dict}, + {"model": model_state, "optimizer": optim_state}, save_path, ) checkpoint_dir = save_path.parent diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8c20bb8bf4..4b04269e3f 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3273,6 +3273,23 @@ def training_args( doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." doc_data_dict = "The multiple definition of the data, used in the multi-task mode." doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)." + doc_zero_stage = ( + "ZeRO optimization stage for distributed training memory reduction. " + "0: standard DDP, lowest communication overhead but highest memory usage " + "(full optimizer states, gradients, and parameters replicated on every GPU). " + "1: DDP + ZeRO stage-1, shards optimizer states across GPUs via " + "ZeroRedundancyOptimizer; same communication volume as DDP (2x model size) " + "but reduces optimizer memory to 1/N per GPU. " + "2: FSDP2 stage-2, shards optimizer states and gradients; same communication " + "volume as stage-1 but further reduces gradient memory to 1/N per GPU. " + "Note: FSDP2 introduces DTensor dispatch overhead that can slow down " + "models with many small layers; use torch.compile to mitigate. " + "3: FSDP2 stage-3, shards parameters as well; maximum memory savings but " + "50% more communication (3x model size) due to parameter all-gather in " + "both forward and backward passes. " + "Default is 0. Requires distributed launch via torchrun. " + "Currently supports single-task training; does not support LKF or change_bias_after_training." + ) arg_training_data = training_data_args() arg_validation_data = validation_data_args() @@ -3395,6 +3412,13 @@ def training_args( default=1, doc=doc_only_pd_supported + doc_acc_freq, ), + Argument( + "zero_stage", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_zero_stage, + ), ] variants = [ Variant( diff --git a/doc/train/parallel-training.md b/doc/train/parallel-training.md index 998f1c3bec..8cc26510c3 100644 --- a/doc/train/parallel-training.md +++ b/doc/train/parallel-training.md @@ -98,6 +98,52 @@ optional arguments: Currently, parallel training in pytorch version is implemented in the form of PyTorch Distributed Data Parallelism [DDP](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). DeePMD-kit will decide whether to launch the training in parallel (distributed) mode or in serial mode depending on your execution command. +### Optional ZeRO memory optimization + +In PyTorch backend, DeePMD-kit supports ZeRO (Zero Redundancy Optimizer) stages +to reduce per-GPU memory usage during distributed training. 
+ +| `zero_stage` | Strategy | Communication | Memory saving | +| ------------ | ----------------------------- | ------------- | --------------------------------------------- | +| 0 | Standard DDP (default) | 2Ψ | None (full replication on every GPU) | +| 1 | DDP + ZeRO Stage 1 | 2Ψ | Optimizer states / N | +| 2 | FSDP2 SHARD_GRAD_OP (Stage 2) | 2Ψ | Gradients + optimizer states / N | +| 3 | FSDP2 FULL_SHARD (Stage 3) | 3Ψ | Parameters + gradients + optimizer states / N | + +**How to choose:** + +- **Stage 0** (DDP): Lowest overhead, fastest training speed. All optimizer states, + gradients, and parameters are fully replicated on every GPU. Use this when GPU + memory is sufficient. +- **Stage 1** (DDP + ZeRO-1): Same communication pattern as DDP (AllReduce), minimal + speed impact. Shards optimizer states only, reducing optimizer memory to 1/N per GPU. + Recommended first step when DDP runs out of memory. +- **Stage 2** (FSDP2): Shards both optimizer states and gradients. Same total + communication volume as stage 1, but uses ReduceScatter + AllGather instead of + AllReduce. FSDP2 introduces DTensor dispatch overhead that can noticeably slow down + models with many small layers; consider `torch.compile` to mitigate. +- **Stage 3** (FSDP2): Maximum memory savings by also sharding parameters, but incurs + 50% more communication (3Ψ) due to parameter all-gather in both forward and backward + passes. Only use when stage 2 still runs out of memory. + +Enable it in input config: + +```json +{ + "training": { + "zero_stage": 1 + } +} +``` + +Constraints: + +- Works only in PyTorch backend. +- Requires distributed launch with `torchrun`. +- Currently single-task only. +- Not supported with `LKF` optimizer. +- `change_bias_after_training` must be `false`. + ### Dataloader and Dataset One of the major differences between two backends during training is that the PyTorch version employs a multi-threaded data loading utility [DataLoader](https://pytorch.org/docs/stable/data.html).
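
For reference, the behavior introduced by this patch can be condensed into a short sketch. This is an illustration under stated assumptions rather than the DeePMD-kit implementation itself: it assumes `torch.distributed` has already been initialized via `torchrun`, that the model lives on the current CUDA device, and the helper names `wrap_for_zero` and `gather_full_state` are invented for this example.

```python
from typing import Any

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)
from torch.distributed.fsdp import fully_shard
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_zero(model: torch.nn.Module, zero_stage: int, lr: float = 1e-3):
    """Return (wrapped_model, optimizer) for zero_stage in {0, 1, 2, 3}."""
    if zero_stage >= 2:
        # FSDP2: fully_shard does not broadcast initial weights (unlike the DDP
        # constructor), so all ranks must already agree on them.
        # reshard_after_forward=True is what turns stage 2 into stage 3:
        # parameters are freed after forward and re-gathered during backward.
        model = fully_shard(model, reshard_after_forward=zero_stage >= 3)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        # Stages 0 and 1 keep plain DDP for gradient averaging.
        model = DDP(model)
        if zero_stage == 1:
            # ZeRO-1: each rank stores only its 1/N shard of the Adam state.
            optimizer = ZeroRedundancyOptimizer(
                model.parameters(), optimizer_class=torch.optim.Adam, lr=lr
            )
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer


def gather_full_state(model, optimizer, zero_stage: int) -> dict[str, Any]:
    """Collect an unsharded checkpoint; only rank 0 should write it to disk."""
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if zero_stage >= 2:
        # Collective calls: every rank must participate; rank 0 ends up with
        # the full, CPU-offloaded state dicts.
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        return {
            "model": get_model_state_dict(model, options=options),
            "optimizer": get_optimizer_state_dict(model, optimizer, options=options),
        }
    if zero_stage == 1:
        # ZeRO-1: pull the optimizer-state shards back onto rank 0 first.
        optimizer.consolidate_state_dict(to=0)
    module = model.module if isinstance(model, DDP) else model
    return {
        "model": module.state_dict(),
        "optimizer": optimizer.state_dict() if rank == 0 else {},
    }
```

The design choice mirrored here is that stages 0 and 1 keep DDP semantics for gradients and only change where the optimizer state lives, while stages 2 and 3 both use FSDP2 and differ only in whether parameters are re-sharded after the forward pass.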