@@ -100,6 +100,10 @@ def __init__(
         self._async_ckpt = AsyncCheckpointer(mode=config.async_mode)
         self._process_group = process_group
         self._pp_rank = pp_rank
+        # Dataloader state stashed during load() when the caller cannot yet
+        # provide a dataloader object. Applied later via
+        # apply_dataloader_state() once the loader is constructed.
+        self._pending_dataloader_state: dict[str, Any] | None = None
 
     def _checkpoint_dir(self, step: int) -> Path:
         return self.base_dir / f"step_{step}"
@@ -170,6 +174,13 @@ def save(
         # Cleanup old checkpoints
         self._cleanup()
 
+        # save() is a collective: non-rank-0 ranks must not return until
+        # rank-0 has committed train_state.pt, metadata.json, and the
+        # latest symlink. Without this barrier, post-save hooks or readers
+        # on other ranks race rank-0's writes (especially on NFS/Lustre).
+        if dist.is_initialized():
+            dist.barrier()
+
     def wait(self) -> None:
         """Block until any pending async checkpoint save completes."""
         self._async_ckpt.wait()
@@ -218,18 +229,46 @@ def load(
         if "optimizer" in dcp_state:
             self.optimizer.load_state_dict(dcp_state["optimizer"])
 
-        # Load non-distributed state
+        # Load non-distributed state. On NFS/Lustre, independent stat()
+        # calls can disagree briefly across ranks; if some ranks enter
+        # this branch and others don't, the broadcast_object_list below
+        # hangs. Use a rank-0-authoritative existence check broadcast to
+        # all ranks so every rank takes the same branch.
         train_state_path = ckpt_dir / _TRAIN_STATE_FILE
-        if train_state_path.exists():
-            train_state = _load_train_state(train_state_path)
+        if dist.is_initialized():
+            exists_flag = [train_state_path.exists() if self._rank == 0 else False]
+            dist.broadcast_object_list(exists_flag, src=0)
+            train_state_exists = bool(exists_flag[0])
+        else:
+            train_state_exists = train_state_path.exists()
+
+        if train_state_exists:
+            # Rank-0-authoritative: only rank 0 reads the file. The
+            # ownership check inside ``_load_train_state`` runs there and
+            # the resulting state is broadcast to all ranks below. Other
+            # ranks pass ``None`` into the broadcast.
+            train_state = (
+                _load_train_state(train_state_path)
+                if self._rank == 0 or not dist.is_initialized()
+                else None
+            )
 
-            # Broadcast from rank 0 to all ranks
+            # Broadcast from rank 0 to all ranks. PyTorch 2.11's
+            # broadcast_object_list does not accept async_op, so a per-op
+            # timeout cannot be wired here; this call inherits the 1800s
+            # process-group default. A wedged rank will still surface, just
+            # later than the other fast-fail paths in this patch.
             if dist.is_initialized():
                 object_list = [train_state if self._rank == 0 else None]
                 dist.broadcast_object_list(object_list, src=0)
                 train_state = object_list[0]
 
             assert train_state is not None, "train_state broadcast failed"
+            # Stash dataloader state if the caller can't yet provide the loader
+            # object. Training loops construct the dataloader after load() so
+            # apply_dataloader_state() can restore it once it exists.
+            if dataloader is None and "dataloader" in train_state:
+                self._pending_dataloader_state = train_state["dataloader"]
             step, tokens_seen, extra = restore_train_state(
                 train_state,
                 scheduler=scheduler,
@@ -240,6 +279,25 @@ def load(
 
         return 0, 0, {}
 
+    def apply_dataloader_state(self, dataloader: Any) -> None:
+        """Apply any dataloader state stashed during load().
+
+        Training loops call load() before constructing the dataloader (since
+        the dataloader depends on phase/annealing state that load() restores).
+        This method applies the stashed state once the loader exists.
+
+        No-op if no state is pending, or if the loader does not support
+        ``load_state_dict`` (e.g., plain torch DataLoader for HF streaming).
+        """
+        if self._pending_dataloader_state is None:
+            return
+        if dataloader is None or not hasattr(dataloader, "load_state_dict"):
+            self._pending_dataloader_state = None
+            return
+        dataloader.load_state_dict(self._pending_dataloader_state)
+        self._pending_dataloader_state = None
+        logger.info("Applied stashed dataloader state")
+
     def _resolve_load_path(self, path: str | None = None) -> Path | None:
         """Resolve the checkpoint path to load from."""
         if path is not None:
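For reference, a minimal sketch of how a training loop is expected to drive the new method (the `ckpt_manager` and `build_dataloader` names and the exact keyword arguments of load() are assumptions for illustration, not taken from this patch): load() runs before the dataloader exists, the loader is built afterwards from the restored phase state, and the stashed dataloader state is then applied.

# Illustrative sketch only; `ckpt_manager` and `build_dataloader` are assumed names.
step, tokens_seen, extra = ckpt_manager.load(scheduler=scheduler, dataloader=None)
dataloader = build_dataloader(extra)             # constructed after load()
ckpt_manager.apply_dataloader_state(dataloader)  # no-op if nothing was stashed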