@@ -221,3 +221,91 @@ def save_final_model_fsdp2(
221221 save_file (model_state_dict , os .path .join (save_directory , "model.safetensors" ))
222222 config .to_json_file (os .path .join (save_directory , "config.json" ))
223223 logger .info (f"Saved final FSDP2 model to { save_directory } " )
224+
225+
226+ # ============================================================================
227+ # DDP Checkpointing
228+ # ============================================================================
229+
230+
231+ def load_checkpoint_ddp (
232+ model : torch .nn .Module ,
233+ optimizer : torch .optim .Optimizer ,
234+ scheduler : torch .optim .lr_scheduler .LRScheduler ,
235+ ckpt_path : str | os .PathLike ,
236+ dist_config : DistributedConfig ,
237+ ) -> CheckpointOutput :
238+ """Load DDP checkpoint."""
239+ checkpoint_path , _ = get_latest_checkpoint (ckpt_path )
240+ if not checkpoint_path :
241+ logger .info ("No DDP checkpoint found, starting from scratch" )
242+ return CheckpointOutput (model , optimizer , scheduler , 0 , 0 )
243+
244+ checkpoint = torch .load (
245+ checkpoint_path / "checkpoint.pt" ,
246+ map_location = f"cuda:{ dist_config .local_rank } " ,
247+ weights_only = True ,
248+ )
249+
250+ model .load_state_dict (checkpoint ["model" ], strict = False )
251+ optimizer .load_state_dict (checkpoint ["optimizer" ])
252+ scheduler .load_state_dict (checkpoint ["scheduler" ])
253+
254+ if dist_config .is_main_process ():
255+ logger .info (f"Loaded DDP checkpoint from step { checkpoint ['step' ]} " )
256+
257+ # Increment the step by one to avoid re-running the previous step.
258+ return CheckpointOutput (model , optimizer , scheduler , checkpoint ["step" ] + 1 , checkpoint ["epoch" ])
259+
260+
261+ def save_checkpoint_ddp (
262+ model : torch .nn .Module ,
263+ optimizer : torch .optim .Optimizer ,
264+ scheduler : torch .optim .lr_scheduler .LRScheduler ,
265+ ckpt_path : str | os .PathLike ,
266+ step : int ,
267+ epoch : int ,
268+ dist_config : DistributedConfig ,
269+ max_checkpoints : int | None = None ,
270+ ) -> None :
271+ """Save DDP checkpoint (rank-0 only since the model is replicated)."""
272+ if not dist_config .is_main_process ():
273+ return
274+
275+ ckpt_path = Path (ckpt_path )
276+ checkpoint_path = ckpt_path / f"step_{ step } "
277+ checkpoint_path .mkdir (parents = True , exist_ok = True )
278+
279+ torch .save (
280+ {
281+ "model" : model .state_dict (),
282+ "optimizer" : optimizer .state_dict (),
283+ "scheduler" : scheduler .state_dict (),
284+ "step" : step ,
285+ "epoch" : epoch ,
286+ },
287+ checkpoint_path / "checkpoint.pt" ,
288+ )
289+ logger .info (f"Saved DDP checkpoint to { checkpoint_path } " )
290+
291+ if max_checkpoints is not None :
292+ prune_checkpoints (ckpt_path , max_checkpoints )
293+
294+
295+ def save_final_model_ddp (
296+ model : torch .nn .Module ,
297+ config ,
298+ save_directory : str | os .PathLike ,
299+ dist_config : DistributedConfig ,
300+ ) -> None :
301+ """Save final model for DDP - only on main process."""
302+ if not dist_config .is_main_process ():
303+ return
304+
305+ # Unwrap DDP if wrapped.
306+ underlying_model = model .module if hasattr (model , "module" ) else model
307+
308+ os .makedirs (save_directory , exist_ok = True )
309+ save_file (underlying_model .state_dict (), os .path .join (save_directory , "model.safetensors" ))
310+ config .to_json_file (os .path .join (save_directory , "config.json" ))
311+ logger .info (f"Saved final DDP model to { save_directory } " )
0 commit comments