 import torch._dynamo

 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
+from torch.distributed.fsdp import (
+    fully_shard,
+)
+from torch.distributed.optim import (
+    ZeroRedundancyOptimizer,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import (
     DataLoader,
@@ -131,14 +143,9 @@ def __init__(
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
         )
-        self.rank = (
-            dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        )
-        self.world_size = (
-            dist.get_world_size()
-            if dist.is_available() and dist.is_initialized()
-            else 1
-        )
+        self.is_distributed = dist.is_available() and dist.is_initialized()
+        self.rank = dist.get_rank() if self.is_distributed else 0
+        self.world_size = dist.get_world_size() if self.is_distributed else 1
         self.num_model = len(self.model_keys)

         # Iteration config
@@ -154,6 +161,15 @@ def __init__(
         self.change_bias_after_training = training_params.get(
             "change_bias_after_training", False
         )
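+        # Stage summary for the code below: 0 = plain DDP, 1 = DDP +
+        # ZeroRedundancyOptimizer (optimizer-state sharding), 2/3 = FSDP2
+        # (stage 3 additionally reshards parameters after each forward).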
+        self.zero_stage = int(training_params.get("zero_stage", 0))
+        if self.zero_stage > 0 and not self.is_distributed:
+            raise ValueError(
+                "training.zero_stage requires distributed launch via torchrun."
+            )
+        if self.zero_stage > 0 and self.change_bias_after_training:
+            raise ValueError(
+                "training.zero_stage does not support change_bias_after_training."
+            )
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
@@ -300,6 +316,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             )
         else:
             self.opt_type, self.opt_param = get_opt_param(training_params)
+        if self.zero_stage > 0 and self.multi_task:
+            raise ValueError(
+                "training.zero_stage is currently only supported in single-task training."
+            )
+        if self.zero_stage > 0 and self.opt_type == "LKF":
+            raise ValueError("training.zero_stage does not support LKF optimizer.")

         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -690,15 +712,25 @@ def single_model_finetune(
                 data_stat_protect=_data_stat_protect[0],
             )

-        if dist.is_available() and dist.is_initialized():
+        if self.is_distributed:
             torch.cuda.set_device(LOCAL_RANK)
-            # DDP will guarantee the model parameters are identical across all processes
-            self.wrapper = DDP(
-                self.wrapper,
-                device_ids=[LOCAL_RANK],
-                find_unused_parameters=True,
-                output_device=LOCAL_RANK,
-            )
+            if self.zero_stage >= 2:
+                # FSDP2 does NOT broadcast params (unlike DDP constructor).
+                # Ensure all ranks share identical weights before sharding.
+                for p in self.wrapper.parameters():
+                    dist.broadcast(p.data, src=0)
+                for b in self.wrapper.buffers():
+                    dist.broadcast(b.data, src=0)
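+                # reshard_after_forward=True frees the gathered parameters
+                # after forward and re-gathers them for backward (ZeRO-3);
+                # False keeps them materialized until backward (ZeRO-2-like).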
+                reshard = self.zero_stage >= 3
+                fully_shard(self.wrapper, reshard_after_forward=reshard)
+            else:
+                # zero_stage=0 or 1: standard DDP (ZeRO-1 will wrap the optimizer)
+                self.wrapper = DDP(
+                    self.wrapper,
+                    device_ids=[LOCAL_RANK],
+                    find_unused_parameters=self.multi_task,
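+                    # (only multi-task graphs may leave parameters unused in
+                    # a step; single-task avoids DDP's unused-parameter scan)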
+                    output_device=LOCAL_RANK,
+                )

         # TODO add lr warmups for multitask
         # author: iProzd
@@ -714,20 +746,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
             if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.Adam,
                     lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
             else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.AdamW,
                     lr=self.lr_exp.start_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
@@ -737,8 +768,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
-            self.optimizer = AdaMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                AdaMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -750,8 +781,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
         elif self.opt_type == "HybridMuon":
-            self.optimizer = HybridMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                HybridMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -764,15 +795,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 muon_2d_only=bool(self.opt_param["muon_2d_only"]),
                 min_2d_dim=int(self.opt_param["min_2d_dim"]),
             )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

+        if self.zero_stage > 0 and self.rank == 0:
+            if self.zero_stage == 1:
+                log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.")
+            else:
+                stage = (
+                    "FULL_SHARD (Stage 3)"
+                    if self.zero_stage >= 3
+                    else "SHARD_GRAD_OP (Stage 2)"
+                )
+                log.info(f"Enabled FSDP2 {stage}.")
+
         # Tensorboard
         self.enable_tensorboard = training_params.get("tensorboard", False)
         self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
@@ -822,6 +863,58 @@ def _log_parameter_count(self) -> None:
                 f"Model Params [{model_key}]: {total / 1e6:.3f}M (Trainable: {trainable / 1e6:.3f}M)"
             )

+    def _create_optimizer(
+        self,
+        optimizer_class: type[torch.optim.Optimizer],
+        **kwargs: Any,
+    ) -> torch.optim.Optimizer:
+        """
+        Construct optimizer, wrapping with ZeroRedundancyOptimizer when zero_stage=1.
+
+        Parameters
+        ----------
+        optimizer_class : type[torch.optim.Optimizer]
+            The optimizer class to instantiate.
+        **kwargs : Any
+            Keyword arguments forwarded to the optimizer constructor.
+
+        Returns
+        -------
+        torch.optim.Optimizer
+            Constructed optimizer instance.
+        """
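+        # ZeroRedundancyOptimizer partitions optimizer states across ranks;
+        # each rank updates only its own partition, then the updated
+        # parameters are broadcast so all replicas stay in sync after step().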
+        if self.zero_stage == 1:
+            return ZeroRedundancyOptimizer(
+                self.wrapper.parameters(),
+                optimizer_class=optimizer_class,
+                **kwargs,
+            )
+        return optimizer_class(self.wrapper.parameters(), **kwargs)
+
+    def _get_inner_module(self) -> ModelWrapper:
+        """Unwrap DDP if needed; FSDP2 shards in place, so no unwrapping is required."""
+        if self.is_distributed and self.zero_stage <= 1:
+            return self.wrapper.module
+        return self.wrapper
+
+    def _load_optimizer_state(
+        self, optimizer_state_dict: dict[str, Any] | None
+    ) -> None:
+        """Load optimizer state for restart training when available."""
+        if optimizer_state_dict is None or not self.restart_training:
+            return
+        if self.zero_stage >= 2:
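+            # broadcast_from_rank0=True means only rank 0 needs the full
+            # (unsharded) optimizer state; it is broadcast and re-sharded
+            # onto each rank's local DTensor shards.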
+            set_optimizer_state_dict(
+                self.wrapper,
+                self.optimizer,
+                optim_state_dict=optimizer_state_dict,
+                options=StateDictOptions(
+                    full_state_dict=True, broadcast_from_rank0=True
+                ),
+            )
+        else:
+            self.optimizer.load_state_dict(optimizer_state_dict)
+
     def run(self) -> None:
         fout = (
             open(
@@ -892,12 +985,16 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    torch.nn.utils.clip_grad_norm_(
+                    # Avoid error_if_nonfinite=True: FSDP2 sharded
+                    # DTensor gradients may not support it. Manual
+                    # isfinite check achieves the same fail-fast behavior.
+                    total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
                         self.gradient_max_norm,
-                        error_if_nonfinite=True,
                     )
-                with torch.device("cpu"):
+                    if not torch.isfinite(total_norm):
+                        raise RuntimeError(f"Non-finite gradient norm: {total_norm}")
+                with torch.device(DEVICE):
                     self.optimizer.step()
                     self.scheduler.step()
             elif self.opt_type == "LKF":
@@ -1205,20 +1302,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     and _step_id != self.start_step
                 )
                 or (display_step_id) == self.num_steps
-            ) and (self.rank == 0 or dist.get_rank() == 0):
+            ) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
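+                # With zero_stage > 0 every rank must enter this branch, since
+                # save_model below performs collective gathers.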
                 # Handle the case if rank 0 aborted and re-assigned
                 self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
-
-                module = (
-                    self.wrapper.module
-                    if dist.is_available() and dist.is_initialized()
-                    else self.wrapper
-                )
                 self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                log.info(f"Saved model to {self.latest_model}")
-                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-                with open("checkpoint", "w") as f:
-                    f.write(str(self.latest_model))
+                if self.rank == 0 or dist.get_rank() == 0:
+                    log.info(f"Saved model to {self.latest_model}")
+                    symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+                    with open("checkpoint", "w") as f:
+                        f.write(str(self.latest_model))

             # tensorboard
             if self.enable_tensorboard and (
@@ -1273,13 +1365,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     with open("checkpoint", "w") as f:
                         f.write(str(self.latest_model))

+        if self.num_steps == 0 and self.zero_stage > 0:
+            # ZeRO-1 / FSDP: all ranks participate in save_model (collective op)
+            self.latest_model = Path(self.save_ckpt + "-0.pt")
+            self.save_model(self.latest_model, lr=0, step=0)
+
         if (
             self.rank == 0 or dist.get_rank() == 0
         ):  # Handle the case if rank 0 aborted and re-assigned
             if self.num_steps == 0:
-                # when num_steps is 0, the checkpoint is never not saved
-                self.latest_model = Path(self.save_ckpt + "-0.pt")
-                self.save_model(self.latest_model, lr=0, step=0)
+                if self.zero_stage == 0:
+                    # When num_steps is 0, the checkpoint is never saved in the loop
+                    self.latest_model = Path(self.save_ckpt + "-0.pt")
+                    self.save_model(self.latest_model, lr=0, step=0)
             log.info(f"Saved model to {self.latest_model}")
             symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
             with open("checkpoint", "w") as f:
@@ -1321,18 +1419,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
         )

     def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
-        module = (
-            self.wrapper.module
-            if dist.is_available() and dist.is_initialized()
-            else self.wrapper
-        )
+        module = self._get_inner_module()
         module.train_infos["lr"] = float(lr)
         module.train_infos["step"] = step
-        optim_state_dict = deepcopy(self.optimizer.state_dict())
-        for item in optim_state_dict["param_groups"]:
+
+        # === Collect state dicts ===
+        if self.zero_stage >= 2:
+            # FSDP2: collective op, all ranks participate; rank 0 gets full state
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
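+            # With full_state_dict=True and cpu_offload=True, only rank 0
+            # receives the gathered tensors; other ranks get empty dicts,
+            # which the rank-0 early return below relies on.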
+            model_state = get_model_state_dict(self.wrapper, options=options)
+            optim_state = get_optimizer_state_dict(
+                self.wrapper, self.optimizer, options=options
+            )
+        elif self.zero_stage == 1:
+            # ZeRO-1: consolidate sharded optimizer state to rank 0
+            model_state = module.state_dict()
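+            # consolidate_state_dict is a collective call: every rank must
+            # reach it, but only the destination rank (0) receives the
+            # merged optimizer state.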
+            self.optimizer.consolidate_state_dict(to=0)
+            optim_state = (
+                deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
+            )
+        else:
+            model_state = module.state_dict()
+            optim_state = deepcopy(self.optimizer.state_dict())
+
+        # === Only rank 0 writes to disk ===
+        if self.rank != 0:
+            return
+        for item in optim_state["param_groups"]:
             item["lr"] = float(item["lr"])
         torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
+            {"model": model_state, "optimizer": optim_state},
             save_path,
         )
         checkpoint_dir = save_path.parent