 import torch._dynamo

 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
+from torch.distributed.fsdp import (
+    fully_shard,
+)
+from torch.distributed.optim import (
+    ZeroRedundancyOptimizer,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import (
     DataLoader,
@@ -131,14 +143,9 @@ def __init__(
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
         )
-        self.rank = (
-            dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        )
-        self.world_size = (
-            dist.get_world_size()
-            if dist.is_available() and dist.is_initialized()
-            else 1
-        )
+        self.is_distributed = dist.is_available() and dist.is_initialized()
+        self.rank = dist.get_rank() if self.is_distributed else 0
+        self.world_size = dist.get_world_size() if self.is_distributed else 1
         self.num_model = len(self.model_keys)

         # Iteration config
@@ -154,6 +161,19 @@ def __init__(
         self.change_bias_after_training = training_params.get(
             "change_bias_after_training", False
         )
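+        # ZeRO/FSDP sharding level: 0 = no sharding (plain DDP when distributed),
+        # 1 = ZeRO-1 optimizer-state sharding (ZeroRedundancyOptimizer on top of DDP),
+        # 2 = FSDP2 without resharding after forward, 3 = FSDP2 with full resharding.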
+        self.zero_stage = int(training_params.get("zero_stage", 0))
+        if self.zero_stage not in (0, 1, 2, 3):
+            raise ValueError(
+                f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
+            )
+        if self.zero_stage > 0 and not self.is_distributed:
+            raise ValueError(
+                "training.zero_stage requires a distributed launch via torchrun."
+            )
+        if self.zero_stage > 0 and self.change_bias_after_training:
+            raise ValueError(
+                "training.zero_stage does not support change_bias_after_training."
+            )
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
@@ -300,6 +320,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             )
         else:
             self.opt_type, self.opt_param = get_opt_param(training_params)
+        if self.zero_stage > 0 and self.multi_task:
+            raise ValueError(
+                "training.zero_stage is currently only supported in single-task training."
+            )
+        if self.zero_stage > 0 and self.opt_type == "LKF":
+            raise ValueError("training.zero_stage does not support the LKF optimizer.")

         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -690,15 +716,25 @@ def single_model_finetune(
                 data_stat_protect=_data_stat_protect[0],
             )

-        if dist.is_available() and dist.is_initialized():
+        if self.is_distributed:
             torch.cuda.set_device(LOCAL_RANK)
-            # DDP will guarantee the model parameters are identical across all processes
-            self.wrapper = DDP(
-                self.wrapper,
-                device_ids=[LOCAL_RANK],
-                find_unused_parameters=True,
-                output_device=LOCAL_RANK,
-            )
+            if self.zero_stage >= 2:
+                # FSDP2 does NOT broadcast parameters (unlike the DDP constructor),
+                # so ensure all ranks share identical weights before sharding.
+                for p in self.wrapper.parameters():
+                    dist.broadcast(p.data, src=0)
+                for b in self.wrapper.buffers():
+                    dist.broadcast(b.data, src=0)
+                reshard = self.zero_stage >= 3
+                self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard)
+            else:
+                # zero_stage = 0 or 1: standard DDP (ZeRO-1 wraps the optimizer below)
+                self.wrapper = DDP(
+                    self.wrapper,
+                    device_ids=[LOCAL_RANK],
+                    find_unused_parameters=True,
+                    output_device=LOCAL_RANK,
+                )

         # TODO add lr warmups for multitask
         # author: iProzd
@@ -714,20 +750,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
             if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.Adam,
                     lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
             else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
+                self.optimizer = self._create_optimizer(
+                    torch.optim.AdamW,
                     lr=self.lr_exp.start_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
+                    fused=DEVICE.type != "cpu",
                 )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
@@ -737,8 +772,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
-            self.optimizer = AdaMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                AdaMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -750,8 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
         elif self.opt_type == "HybridMuon":
-            self.optimizer = HybridMuonOptimizer(
-                self.wrapper.parameters(),
+            self.optimizer = self._create_optimizer(
+                HybridMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
@@ -764,15 +799,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 muon_2d_only=bool(self.opt_param["muon_2d_only"]),
                 min_2d_dim=int(self.opt_param["min_2d_dim"]),
             )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
+            self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

+        if self.zero_stage > 0 and self.rank == 0:
+            if self.zero_stage == 1:
+                log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.")
+            else:
+                stage = (
+                    "FULL_SHARD (Stage 3)"
+                    if self.zero_stage >= 3
+                    else "SHARD_GRAD_OP (Stage 2)"
+                )
+                log.info(f"Enabled FSDP2 {stage}.")
+
         # Tensorboard
         self.enable_tensorboard = training_params.get("tensorboard", False)
         self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
@@ -822,6 +867,58 @@ def _log_parameter_count(self) -> None:
                 f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
             )

+    def _create_optimizer(
+        self,
+        optimizer_class: type[torch.optim.Optimizer],
+        **kwargs: Any,
+    ) -> torch.optim.Optimizer:
+        """
+        Construct the optimizer, wrapping it with ZeroRedundancyOptimizer when zero_stage=1.
+
+        Parameters
+        ----------
+        optimizer_class : type[torch.optim.Optimizer]
+            The optimizer class to instantiate.
+        **kwargs : Any
+            Keyword arguments forwarded to the optimizer constructor.
+
+        Returns
+        -------
+        torch.optim.Optimizer
+            The constructed optimizer instance.
+        """
+        if self.zero_stage == 1:
+            return ZeroRedundancyOptimizer(
+                self.wrapper.parameters(),
+                optimizer_class=optimizer_class,
+                **kwargs,
+            )
+        return optimizer_class(self.wrapper.parameters(), **kwargs)
+
+    def _get_inner_module(self) -> ModelWrapper:
+        """Unwrap DDP if needed; FSDP2 shards the module in place, so no unwrapping is required."""
+        if self.is_distributed and self.zero_stage <= 1:
+            return self.wrapper.module
+        return self.wrapper
+
+    def _load_optimizer_state(
+        self, optimizer_state_dict: dict[str, Any] | None
+    ) -> None:
+        """Load the optimizer state for restart training when available."""
+        if optimizer_state_dict is None or not self.restart_training:
+            return
+        if self.zero_stage >= 2:
+            set_optimizer_state_dict(
+                self.wrapper,
+                self.optimizer,
+                optim_state_dict=optimizer_state_dict,
+                options=StateDictOptions(
+                    full_state_dict=True, broadcast_from_rank0=True
+                ),
+            )
+        else:
+            self.optimizer.load_state_dict(optimizer_state_dict)
+
     def run(self) -> None:
         fout = (
             open(
@@ -892,12 +989,30 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    torch.nn.utils.clip_grad_norm_(
+                    # FSDP2's sharded DTensor gradients do not support error_if_nonfinite;
+                    # use a manual isfinite check instead.
+                    total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
                         self.gradient_max_norm,
-                        error_if_nonfinite=True,
                     )
-                with torch.device("cpu"):
+                    if not torch.isfinite(total_norm):
+                        bad_params = []
+                        for name, p in self.wrapper.named_parameters():
+                            if p.grad is not None:
+                                grad_norm = p.grad.data.norm()
+                                if not torch.isfinite(grad_norm):
+                                    bad_params.append(
+                                        f"  {name}: grad_norm={grad_norm}, shape={list(p.shape)}"
+                                    )
+                        detail = (
+                            "\n".join(bad_params)
+                            if bad_params
+                            else "  (all individual grads finite; overflow in norm reduction)"
+                        )
+                        raise RuntimeError(
+                            f"Non-finite gradient norm: {total_norm}\n"
+                            f"Parameters with non-finite gradients:\n{detail}"
+                        )
+                with torch.device(DEVICE):
                     self.optimizer.step()
                     self.scheduler.step()
             elif self.opt_type == "LKF":
@@ -1205,20 +1320,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     and _step_id != self.start_step
                 )
                 or (display_step_id) == self.num_steps
-            ) and (self.rank == 0 or dist.get_rank() == 0):
+            ) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
                 # Handle the case if rank 0 aborted and re-assigned
                 self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
-
-                module = (
-                    self.wrapper.module
-                    if dist.is_available() and dist.is_initialized()
-                    else self.wrapper
-                )
                 self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                log.info(f"Saved model to {self.latest_model}")
-                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-                with open("checkpoint", "w") as f:
-                    f.write(str(self.latest_model))
+                if self.rank == 0 or dist.get_rank() == 0:
+                    log.info(f"Saved model to {self.latest_model}")
+                    symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+                    with open("checkpoint", "w") as f:
+                        f.write(str(self.latest_model))

             # tensorboard
             if self.enable_tensorboard and (
@@ -1273,13 +1383,19 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                 with open("checkpoint", "w") as f:
                     f.write(str(self.latest_model))

+        if self.num_steps == 0 and self.zero_stage > 0:
+            # ZeRO-1 / FSDP: all ranks must participate in save_model (collective op)
+            self.latest_model = Path(self.save_ckpt + "-0.pt")
+            self.save_model(self.latest_model, lr=0, step=0)
+
         if (
             self.rank == 0 or dist.get_rank() == 0
         ):  # Handle the case if rank 0 aborted and re-assigned
             if self.num_steps == 0:
-                # when num_steps is 0, the checkpoint is never not saved
-                self.latest_model = Path(self.save_ckpt + "-0.pt")
-                self.save_model(self.latest_model, lr=0, step=0)
+                if self.zero_stage == 0:
+                    # when num_steps is 0, the checkpoint is never saved in the loop
+                    self.latest_model = Path(self.save_ckpt + "-0.pt")
+                    self.save_model(self.latest_model, lr=0, step=0)
             log.info(f"Saved model to {self.latest_model}")
             symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
             with open("checkpoint", "w") as f:
@@ -1321,18 +1437,36 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
         )

     def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
-        module = (
-            self.wrapper.module
-            if dist.is_available() and dist.is_initialized()
-            else self.wrapper
-        )
+        module = self._get_inner_module()
         module.train_infos["lr"] = float(lr)
         module.train_infos["step"] = step
-        optim_state_dict = deepcopy(self.optimizer.state_dict())
-        for item in optim_state_dict["param_groups"]:
+
+        # === Collect state dicts ===
+        if self.zero_stage >= 2:
+            # FSDP2: collective op, all ranks participate; rank 0 gets the full state
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+            model_state = get_model_state_dict(self.wrapper, options=options)
+            optim_state = get_optimizer_state_dict(
+                self.wrapper, self.optimizer, options=options
+            )
+        elif self.zero_stage == 1:
+            # ZeRO-1: consolidate the sharded optimizer state onto rank 0
+            model_state = module.state_dict()
+            self.optimizer.consolidate_state_dict(to=0)
+            optim_state = (
+                deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
+            )
+        else:
+            model_state = module.state_dict()
+            optim_state = deepcopy(self.optimizer.state_dict())
+
+        # === Only rank 0 writes to disk ===
+        if self.rank != 0:
+            return
+        for item in optim_state["param_groups"]:
             item["lr"] = float(item["lr"])
         torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
+            {"model": model_state, "optimizer": optim_state},
             save_path,
         )
         checkpoint_dir = save_path.parent