@@ -265,6 +265,66 @@ def _fsdp_uses_flat_param_optim(model: torch.nn.Module | None) -> bool:
265265 return not getattr (model , "_use_orig_params" , True )
266266
267267
268+ def _get_cl_param_fqns (opt_model : torch .nn .Module | None ) -> set [str ]:
269+ """Return FQNs of FSDP-managed original params recorded as channels_last.
270+
271+ For every FSDP submodule in *opt_model*, reads ``flat_param._fqns`` /
272+ ``_shapes`` / ``_strides`` / ``_contiguities`` and returns the set of
273+ original-parameter FQNs whose ``_contiguities[i] is False`` and whose
274+ recorded strides match ``channels_last`` (4-D) or ``channels_last_3d``
275+ (5-D). That is the same bit ``_get_unflat_views`` consults to decide
276+ ``view`` vs ``as_strided`` on save -- so ``_contiguities[i] is False``
277+ is exactly the signal that the destination ``FlatParameter`` slot
278+ expects NHWC storage order at load time.
279+
280+ Returns an empty set when *opt_model* isn't FSDP+``use_orig_params=False``
281+ (the only configuration where the flatten/unflatten asymmetry exists).
282+
283+ Each FQN is built as ``{module_path_to_FSDP}.{flat_param._fqns[i]}``,
284+ matching DCP's ``_get_fqns`` convention -- specifically, FSDP's
285+ ``_fsdp_wrapped_module`` segments are stripped from the path so the
286+ returned FQNs line up with the keys in ``optim_sd["state"]``.
287+
288+ The ``_orig_mod.`` (``torch.compile``) prefix is also stripped, matching
289+ the normalization ``save_checkpoint`` applies to optimizer
290+ ``param_names``.
291+ """
292+ if not _fsdp_uses_flat_param_optim (opt_model ):
293+ return set ()
294+
295+ cl_fqns : set [str ] = set ()
296+ for module_name , module in opt_model .named_modules ():
297+ if not isinstance (module , FSDP ):
298+ continue
299+ flat_param = getattr (module , "_flat_param" , None )
300+ if flat_param is None :
301+ continue
302+ # DCP's ``_get_fqns`` skips the ``_fsdp_wrapped_module`` attribute
303+ # when building parameter FQNs; mirror that by removing the segment
304+ # from the module path.
305+ path_segments = [
306+ seg
307+ for seg in module_name .split ("." )
308+ if seg and seg != "_fsdp_wrapped_module"
309+ ]
310+ prefix = "." .join (path_segments )
311+ if prefix :
312+ prefix += "."
313+ for fqn , shape , stride , contig in zip (
314+ flat_param ._fqns ,
315+ flat_param ._shapes ,
316+ flat_param ._strides ,
317+ flat_param ._contiguities ,
318+ ):
319+ if contig :
320+ continue
321+ # CL / CL3D both have channel stride == 1 (channel is the
322+ # innermost / fastest-varying dim in NHWC / NDHWC storage).
323+ if len (shape ) in (4 , 5 ) and len (stride ) == len (shape ) and stride [1 ] == 1 :
324+ cl_fqns .add ((prefix + fqn ).removeprefix ("_orig_mod." ))
325+ return cl_fqns
326+
327+
268328def _remap_channels_last_optim_sd (
269329 opt_model : torch .nn .Module | None ,
270330 optim_sd : dict [str , Any ],
@@ -282,24 +342,27 @@ def _remap_channels_last_optim_sd(
282342 For a 4-D Conv2d weight in ``channels_last`` format the two orders
283343 differ, so the round-trip silently corrupts the optimizer state.
284344
285- Detect channels_last entries directly on *optim_sd* (the saved tensor
286- preserves its memory format through ``torch.save`` / ``torch.load``)
287- and pre-permute them so the loader's ``torch.flatten`` produces the
288- same byte sequence the ``FlatParameter`` was originally filled with.
289- Other entries (and ranks that received an empty ``optim_sd`` for the
290- broadcast-from-rank-0 path) pass through unchanged.
345+ The remap is gated on the **destination** ``FlatParameter`` slot's
346+ expected byte order (via ``flat_param._contiguities``), not on the
347+ saved tensor's layout. That's the only signal that always matches what
348+ the load-side ``_flatten_tensor_optim_state`` will do, so it works for
349+ every save/load layout combination -- in particular for
350+ ``FSDP+ShardTensor`` configurations where ``distribute_module`` calls
351+ ``.contiguous()`` and silently strips channels_last before FSDP wraps,
352+ making saved tensors standard-contig even though the conceptual model
353+ has CL conv weights.
354+
355+ Inputs are also normalized to standard contiguity before the layout
356+ decision: a CL tensor that *isn't* getting permuted (because the
357+ destination is non-CL) would otherwise survive into DCP's per-tensor
358+ ``dist.broadcast`` and hit the same layout-blind broadcast bug
359+ ``_force_standard_contiguous`` fixes for model state.
291360
292361 Only fires when *opt_model* is FSDP-wrapped with
293362 ``use_orig_params=False`` -- with ``use_orig_params=True`` the
294363 asymmetry doesn't exist and the remap would *cause* the corruption it
295364 is meant to prevent.
296365
297- Note: we cannot inspect the live model to find channels_last params --
298- with ``use_orig_params=False`` the original parameters are hidden
299- behind plain tensor attributes and ``named_parameters()`` only sees
300- the 1-D ``FlatParameter``. So detection is on the saved tensors
301- instead.
302-
303366 See ``torch/distributed/fsdp/_optim_utils.py::_flatten_tensor_optim_state``
304367 and ``_flat_param.py::flatten_tensors``.
305368 """
@@ -308,12 +371,21 @@ def _remap_channels_last_optim_sd(
308371 if not _fsdp_uses_flat_param_optim (opt_model ):
309372 return optim_sd
310373
311- def _maybe_remap (t : torch .Tensor ) -> torch .Tensor :
312- if isinstance (t , DTensor ):
374+ cl_fqns = _get_cl_param_fqns (opt_model )
375+
376+ def _normalize (t : torch .Tensor , is_cl_dest : bool ) -> torch .Tensor :
377+ if isinstance (t , DTensor ) or t .dim () == 0 :
313378 return t
314- if t .dim () == 4 and t .is_contiguous (memory_format = torch .channels_last ):
379+ # Force standard contiguity first so any saved-CL bytes are
380+ # rewritten in NCHW storage order before the layout decision; this
381+ # makes the subsequent broadcast inside DCP layout-safe whether or
382+ # not we permute.
383+ t = t .contiguous ()
384+ if not is_cl_dest :
385+ return t
386+ if t .dim () == 4 :
315387 return t .permute (0 , 2 , 3 , 1 ).contiguous ().view (* t .shape )
316- if t .dim () == 5 and t . is_contiguous ( memory_format = torch . channels_last_3d ) :
388+ if t .dim () == 5 :
317389 return t .permute (0 , 2 , 3 , 4 , 1 ).contiguous ().view (* t .shape )
318390 return t
319391
@@ -322,9 +394,10 @@ def _maybe_remap(t: torch.Tensor) -> torch.Tensor:
322394 if not isinstance (pstate , dict ):
323395 new_state [pname ] = pstate
324396 continue
397+ is_cl_dest = pname .removeprefix ("_orig_mod." ) in cl_fqns
325398 new_ps : dict [str , Any ] = {}
326399 for k , v in pstate .items ():
327- new_ps [k ] = _maybe_remap ( v ) if isinstance (v , torch .Tensor ) else v
400+ new_ps [k ] = _normalize ( v , is_cl_dest ) if isinstance (v , torch .Tensor ) else v
328401 new_state [pname ] = new_ps
329402
330403 return {** optim_sd , "state" : new_state }
@@ -1155,16 +1228,32 @@ def _load_checkpoint_distributed(
11551228 path , index = epoch , model_type = "pt" , distributed = True
11561229 )
11571230
1231+ # Broadcast file existence so all ranks agree on whether to enter the
1232+ # (collective) optimizer load. Without this, a rundir that has model
1233+ # weights but no training checkpoint -- e.g. fine-tuning from a
1234+ # weights-only export -- would have rank 0 enter ``set_optimizer_state_dict``
1235+ # with an empty dict and trip the "missing 'state'" error inside DCP.
1236+ ckpt_exists = fs .exists (checkpoint_filename ) if is_rank0 else None
1237+ ckpt_flags : list [Any ] = [ckpt_exists ]
1238+ torch .distributed .broadcast_object_list (ckpt_flags , src = 0 )
1239+ ckpt_exists = ckpt_flags [0 ]
1240+
1241+ if not ckpt_exists :
1242+ checkpoint_logging .warning (
1243+ f"No training checkpoint at { checkpoint_filename } ; "
1244+ "skipping optimizer/scheduler/scaler load"
1245+ )
1246+ return 0
1247+
11581248 checkpoint_dict : dict [str , Any ] = {}
11591249 if is_rank0 :
1160- if fs .exists (checkpoint_filename ):
1161- file_to_load = _cache_if_needed (checkpoint_filename )
1162- checkpoint_dict = torch .load (
1163- file_to_load , map_location = device , weights_only = False
1164- )
1165- checkpoint_logging .success (
1166- f"Loaded checkpoint file { checkpoint_filename } to device { device } "
1167- )
1250+ file_to_load = _cache_if_needed (checkpoint_filename )
1251+ checkpoint_dict = torch .load (
1252+ file_to_load , map_location = device , weights_only = False
1253+ )
1254+ checkpoint_logging .success (
1255+ f"Loaded checkpoint file { checkpoint_filename } to device { device } "
1256+ )
11681257
11691258 # Optimizer state via DCP (collective)
11701259 if optimizer :
0 commit comments