@@ -163,6 +163,263 @@ def test_fsdp_checkpoint_roundtrip(
163163 assert scheduler2 .last_epoch == scheduler .last_epoch
164164
165165
166+ # ---------------------------------------------------------------------------
167+ # Plain FSDP + channels_last (regression for cross-rank layout mismatch)
168+ # ---------------------------------------------------------------------------
169+
170+
171+ class _ConvNet (nn .Module ):
172+ """Tiny conv net so the parameter set includes a 4-D weight."""
173+
174+ def __init__ (self , in_ch : int = 4 , out_ch : int = 8 , k : int = 3 ):
175+ super ().__init__ ()
176+ self .conv = nn .Conv2d (in_ch , out_ch , kernel_size = k , padding = k // 2 , bias = True )
177+ self .gn = nn .GroupNorm (num_groups = 4 , num_channels = out_ch )
178+
179+ def forward (self , x ):
180+ return self .gn (self .conv (x ))
181+
182+
183+ def _all_ranks_bit_exact (t : torch .Tensor ) -> bool :
184+ """True iff every rank holds element-wise bit-identical values for *t*."""
185+ t_min = t .detach ().clone ().float ()
186+ t_max = t .detach ().clone ().float ()
187+ dist .all_reduce (t_min , op = dist .ReduceOp .MIN )
188+ dist .all_reduce (t_max , op = dist .ReduceOp .MAX )
189+ return torch .equal (t_min , t_max )
190+
191+
192+ @pytest .mark .timeout (30 )
193+ @pytest .mark .multigpu_static
194+ @pytest .mark .parametrize ("use_orig_params" , [True , False ])
195+ @pytest .mark .parametrize (
196+ "sharding_strategy" ,
197+ [ShardingStrategy .NO_SHARD ],
198+ )
199+ def test_fsdp_checkpoint_channels_last_roundtrip (
200+ shared_tmp_dir , use_orig_params , sharding_strategy
201+ ):
202+ """Round-trip an FSDP+channels_last conv model and assert per-rank parity.
203+
204+ Regression for a layout-mismatch bug in DCP's broadcast_from_rank0 path:
205+ ``dist.broadcast`` accepts a channels_last sender (``is_contiguous`` check
206+ passes for that format) but transfers bytes in storage order, while
207+ receivers allocate ``torch.empty(shape, dtype, device)`` (standard NCHW),
208+ so 4-D conv weights were silently permuted on non-rank-0. The fix
209+ (``_force_standard_contiguous`` on rank 0 before ``set_model_state_dict``)
210+ keeps sender and receiver layouts consistent.
211+
212+ Asserts bit-exact agreement across ranks on the live FlatParameter (for
213+ ``use_orig_params=False``) / each original parameter (for True). Output
214+ equivalence isn't sufficient — a permuted conv weight preserves abs-sum
215+ and the model can stagger toward similar outputs over noise — so we
216+ check the parameter values directly.
217+
218+ The optimizer state is intentionally *not* asserted here. The optim load
219+ path is layout-correct (verified by the standalone smoketest and by
220+ running this test in isolation), but suite-level state pollution (NCCL
221+ allreduce ordering across many prior tests) accumulates FP noise in the
222+ pre-load training step, which then survives the load and makes a tight
223+ cross-rank check flaky. The existing ``test_fsdp_checkpoint_roundtrip``
224+ already covers the optim path with a tolerance-based output comparison
225+ that's robust to that noise.
226+ """
227+ dm = DistributedManager ()
228+ if dm .world_size < 2 :
229+ pytest .skip ("Need at least 2 ranks" )
230+
231+ device = dm .device
232+ mesh = init_device_mesh ("cuda" , (dm .world_size ,), mesh_dim_names = ("world" ,))
233+
234+ # Build, move to channels_last (only conv weights are affected), wrap.
235+ torch .manual_seed (0 )
236+ model = _ConvNet ().to (device = device , memory_format = torch .channels_last )
237+ fsdp_model = FSDP (
238+ model ,
239+ device_mesh = mesh ["world" ],
240+ sharding_strategy = sharding_strategy ,
241+ use_orig_params = use_orig_params ,
242+ sync_module_states = True ,
243+ )
244+ optimizer = torch .optim .Adam (fsdp_model .parameters (), lr = 1e-3 )
245+
246+ # Same x on every rank (we want the source-of-truth state to be identical
247+ # across ranks pre-save, so any post-load divergence is checkpoint-induced).
248+ x = torch .randn (2 , 4 , 8 , 8 , device = device ).contiguous (
249+ memory_format = torch .channels_last
250+ )
251+ for _ in range (2 ):
252+ fsdp_model (x ).sum ().backward ()
253+ optimizer .step ()
254+ optimizer .zero_grad ()
255+
256+ save_checkpoint (
257+ shared_tmp_dir ,
258+ models = fsdp_model ,
259+ optimizer = optimizer ,
260+ epoch = 2 ,
261+ optimizer_model = fsdp_model ,
262+ )
263+ dist .barrier ()
264+
265+ # Build a *differently-seeded* fresh model so sync_module_states alone can't
266+ # mask the bug by leaving rank 0's pre-load values on every rank.
267+ torch .manual_seed (dm .rank + 1234 )
268+ model2 = _ConvNet ().to (device = device , memory_format = torch .channels_last )
269+ fsdp_model2 = FSDP (
270+ model2 ,
271+ device_mesh = mesh ["world" ],
272+ sharding_strategy = sharding_strategy ,
273+ use_orig_params = use_orig_params ,
274+ sync_module_states = True ,
275+ )
276+ optimizer2 = torch .optim .Adam (fsdp_model2 .parameters (), lr = 1e-3 )
277+ # Step once so optimizer state is shaped before the load.
278+ fsdp_model2 (x ).sum ().backward ()
279+ optimizer2 .step ()
280+ optimizer2 .zero_grad ()
281+
282+ epoch = load_checkpoint (
283+ shared_tmp_dir ,
284+ models = fsdp_model2 ,
285+ optimizer = optimizer2 ,
286+ optimizer_model = fsdp_model2 ,
287+ )
288+ assert epoch == 2
289+
290+ # --- Cross-rank parity checks ------------------------------------------
291+ # FlatParameter (use_orig_params=False) or each original param (True).
292+ if use_orig_params :
293+ for name , p in fsdp_model2 .named_parameters ():
294+ assert _all_ranks_bit_exact (p ), (
295+ f"Parameter '{ name } ' (shape={ tuple (p .shape )} ) differs across "
296+ f"ranks after channels_last+FSDP load"
297+ )
298+ else :
299+ flat_param = fsdp_model2 ._flat_param
300+ assert _all_ranks_bit_exact (flat_param ), (
301+ "FlatParameter differs across ranks after channels_last+FSDP load"
302+ )
303+
304+ # Optimizer state cross-rank check intentionally omitted -- see docstring.
305+
306+
307+ # ---------------------------------------------------------------------------
308+ # Cross-mode load: 1-proc non-distributed save → N-proc FSDP load (with CL)
309+ # ---------------------------------------------------------------------------
310+
311+
312+ @pytest .mark .timeout (30 )
313+ @pytest .mark .multigpu_static
314+ def test_cross_mode_channels_last_model_load (shared_tmp_dir ):
315+ """Save from a single (rank-0-only) non-FSDP CL model; load model state
316+ into N-proc FSDP.
317+
318+ Realistic "trained on multi-rank, fine-tuned/inspected on a single GPU,
319+ resumed multi-rank" round-trip with channels_last. Confirms that the
320+ on-disk model state produced by the non-distributed save path is loadable
321+ by the distributed FSDP load path on every rank without layout-induced
322+ corruption.
323+
324+ Model side asserts:
325+ * every rank's post-load FlatParameter is bit-exact identical, AND
326+ * rank 0's logical values match what was saved.
327+
328+ Cross-rank parity alone can be satisfied by "everyone got the same wrong
329+ values" (e.g. silent drop), so we also check against the saved snapshot.
330+
331+ Optimizer cross-mode load is *not* tested here. The non-distributed save
332+ path writes int-keyed (param-id) optim state via ``optimizer.state_dict()``,
333+ while the distributed FSDP load path expects FQN-keyed input -- DCP's
334+ ``_split_optim_state_dict`` early-returns for int keys without converting,
335+ and the downstream ``_rekey_sharded_optim_state_dict`` then crashes on
336+ ``int.unflat_param_names``. That's a separate, pre-existing limitation
337+ of cross-mode optim restore; same-mode optim restore is exercised by
338+ ``test_fsdp_checkpoint_channels_last_roundtrip`` and is what the
339+ channels_last fix is concerned with.
340+ """
341+ dm = DistributedManager ()
342+ if dm .world_size < 2 :
343+ pytest .skip ("Need at least 2 ranks" )
344+
345+ device = dm .device
346+
347+ # ===== Phase A: 1-proc save on rank 0 only =====
348+ saved_params : dict [str , torch .Tensor ] = {}
349+ if dm .rank == 0 :
350+ torch .manual_seed (0 )
351+ model_save = _ConvNet ().to (device = device , memory_format = torch .channels_last )
352+ optimizer_save = torch .optim .Adam (model_save .parameters (), lr = 1e-3 )
353+
354+ x = torch .randn (2 , 4 , 8 , 8 , device = device ).contiguous (
355+ memory_format = torch .channels_last
356+ )
357+ # Two steps so the saved weights have actually moved off init.
358+ for _ in range (2 ):
359+ model_save (x ).sum ().backward ()
360+ optimizer_save .step ()
361+ optimizer_save .zero_grad ()
362+
363+ # Snapshot. ``contiguous()`` pins a canonical layout for comparison;
364+ # the values are what matter.
365+ for name , p in model_save .named_parameters ():
366+ saved_params [name ] = p .detach ().clone ().contiguous ().cpu ()
367+
368+ # We deliberately save the optimizer state too, mirroring real-world
369+ # usage, but the load side will not consume it (see docstring).
370+ save_checkpoint (
371+ shared_tmp_dir ,
372+ models = model_save ,
373+ optimizer = optimizer_save ,
374+ epoch = 2 ,
375+ )
376+ dist .barrier ()
377+
378+ # ===== Phase B: N-proc FSDP-only load (model only) =====
379+ mesh = init_device_mesh ("cuda" , (dm .world_size ,), mesh_dim_names = ("world" ,))
380+
381+ # Different per-rank seed so sync_module_states alone can't mask anything.
382+ torch .manual_seed (dm .rank + 4242 )
383+ model_load = _ConvNet ().to (device = device , memory_format = torch .channels_last )
384+ fsdp_load = FSDP (
385+ model_load ,
386+ device_mesh = mesh ["world" ],
387+ sharding_strategy = ShardingStrategy .NO_SHARD ,
388+ use_orig_params = False ,
389+ sync_module_states = True ,
390+ )
391+
392+ # Pass optimizer=None: cross-mode optim load is a separate, pre-existing
393+ # PyTorch DCP limitation (see docstring). We're testing the model path.
394+ epoch = load_checkpoint (
395+ shared_tmp_dir ,
396+ models = fsdp_load ,
397+ )
398+ assert epoch == 2
399+
400+ # ===== Phase C.1: per-rank parity =====
401+ flat_param = fsdp_load ._flat_param
402+ assert _all_ranks_bit_exact (flat_param ), (
403+ "FlatParameter differs across ranks after cross-mode model load"
404+ )
405+
406+ # ===== Phase C.2: loaded values match saved values (rank 0) =====
407+ # Collective: gather the full model state dict on every rank.
408+ full_loaded_model = get_model_state_dict (
409+ fsdp_load , options = StateDictOptions (full_state_dict = True )
410+ )
411+ if dm .rank == 0 :
412+ for name , expected in saved_params .items ():
413+ assert name in full_loaded_model , (
414+ f"Loaded model state missing '{ name } '"
415+ )
416+ actual = full_loaded_model [name ].detach ().contiguous ().cpu ()
417+ assert torch .equal (actual , expected ), (
418+ f"Logical model values for '{ name } ' differ between save and "
419+ f"load (cross-mode)"
420+ )
421+
422+
166423# ---------------------------------------------------------------------------
167424# load_model_weights — plain FSDP
168425# ---------------------------------------------------------------------------
0 commit comments