Skip to content

Commit ec3b375

Browse files
committed
Fix model state channels_last load
1 parent cbb1bf1 commit ec3b375

2 files changed

Lines changed: 291 additions & 1 deletion

File tree

physicsnemo/utils/checkpoint.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,34 @@ def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
113113
return out
114114

115115

116+
def _force_standard_contiguous(state_dict: dict[str, Any]) -> dict[str, Any]:
117+
"""Make positive-dim, non-DTensor tensors standard-contiguous.
118+
119+
Compensates for a layout-blind broadcast inside DCP's
120+
``set_model_state_dict`` / ``set_optimizer_state_dict``: the rank-0→others
121+
transfer goes through ``dist.broadcast``, which is happy to send a
122+
``channels_last`` tensor (its contiguity check passes for that format) but
123+
moves bytes in *storage* order. Receivers on non-zero ranks allocate via
124+
``torch.empty(shape, dtype, device)`` (standard NCHW), so CL bytes land
125+
in NCHW positions and values are silently permuted on receive.
126+
127+
Forcing standard contiguity on rank 0 makes the broadcast layout-consistent.
128+
``.contiguous()`` is a no-op for tensors that are already standard-contig,
129+
a value-preserving (logical) copy for ``channels_last`` /
130+
``channels_last_3d``, and is intentionally skipped for ``DTensor`` (whose
131+
contiguity semantics differ).
132+
"""
133+
out: dict[str, Any] = {}
134+
for k, v in state_dict.items():
135+
if isinstance(v, torch.Tensor) and not isinstance(v, DTensor) and v.dim() > 0:
136+
out[k] = v.contiguous()
137+
elif isinstance(v, dict):
138+
out[k] = _force_standard_contiguous(v)
139+
else:
140+
out[k] = v
141+
return out
142+
143+
116144
def _get_dtensor_param_placements(
117145
model: torch.nn.Module,
118146
) -> dict[str, tuple[Any, tuple[Any, ...]]]:
@@ -1101,8 +1129,13 @@ def _load_checkpoint_distributed(
11011129
set_model_state_dict(model, sd, options=full_options)
11021130
else:
11031131
# FSDP-managed DTensors (FULL_SHARD/SHARD_GRAD_OP) or no
1104-
# DTensors at all — broadcast_from_rank0 handles both.
1132+
# DTensors at all — broadcast_from_rank0 handles both. Force
1133+
# standard contiguity on rank 0 first so the per-tensor
1134+
# broadcast inside DCP doesn't permute channels_last params on
1135+
# receive (see ``_force_standard_contiguous`` for the why).
11051136
sd = model_state_dicts.get(name, {}) if is_rank0 else {}
1137+
if is_rank0:
1138+
sd = _force_standard_contiguous(sd)
11061139
set_model_state_dict(model, sd, options=broadcast_options)
11071140
else:
11081141
# A mix of distributed and non-distributed models is valid

test/utils/test_checkpoint_distributed.py

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,263 @@ def test_fsdp_checkpoint_roundtrip(
163163
assert scheduler2.last_epoch == scheduler.last_epoch
164164

165165

166+
# ---------------------------------------------------------------------------
167+
# Plain FSDP + channels_last (regression for cross-rank layout mismatch)
168+
# ---------------------------------------------------------------------------
169+
170+
171+
class _ConvNet(nn.Module):
    """Minimal conv model whose parameter set contains a 4-D weight.

    A ``Conv2d`` (with bias, padded by ``k // 2``) followed by a 4-group
    ``GroupNorm`` over the convolution's output channels.
    """

    def __init__(self, in_ch: int = 4, out_ch: int = 8, k: int = 3):
        super().__init__()
        half_kernel = k // 2
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            padding=half_kernel,
            bias=True,
        )
        self.gn = nn.GroupNorm(num_groups=4, num_channels=out_ch)

    def forward(self, x):
        """Apply the convolution, then normalize its output."""
        features = self.conv(x)
        return self.gn(features)
181+
182+
183+
def _all_ranks_bit_exact(t: torch.Tensor) -> bool:
    """True iff every rank holds element-wise bit-identical values for *t*.

    The comparison is done on the raw bytes of the tensor rather than on a
    ``float32`` cast of it: a cast would round ``float64`` / ``int64`` values
    (hiding real differences), would treat ``+0.0`` and ``-0.0`` as equal, and
    would make any ``NaN`` report a mismatch even when all ranks agree.

    Collective: every rank must call this with a tensor of the same shape and
    dtype. *t* itself is never modified; the reductions run on clones.
    """
    # Flatten in logical (row-major) order so ranks that hold the same values
    # in different memory formats still compare equal, then reinterpret the
    # elements as bytes. ``reshape(-1)`` also makes 0-dim tensors viewable.
    raw = t.detach().reshape(-1).contiguous().view(torch.uint8)
    lowest = raw.clone()
    highest = raw.clone()
    dist.all_reduce(lowest, op=dist.ReduceOp.MIN)
    dist.all_reduce(highest, op=dist.ReduceOp.MAX)
    # Per-byte min == max across ranks  <=>  every rank holds the same byte.
    return torch.equal(lowest, highest)
190+
191+
192+
@pytest.mark.timeout(30)
@pytest.mark.multigpu_static
@pytest.mark.parametrize("use_orig_params", [True, False])
@pytest.mark.parametrize(
    "sharding_strategy",
    [ShardingStrategy.NO_SHARD],
)
def test_fsdp_checkpoint_channels_last_roundtrip(
    shared_tmp_dir, use_orig_params, sharding_strategy
):
    """Round-trip an FSDP+channels_last conv model and assert per-rank parity.

    Regression for a layout-mismatch bug in DCP's broadcast_from_rank0 path:
    ``dist.broadcast`` accepts a channels_last sender (``is_contiguous`` check
    passes for that format) but transfers bytes in storage order, while
    receivers allocate ``torch.empty(shape, dtype, device)`` (standard NCHW),
    so 4-D conv weights were silently permuted on non-rank-0. The fix
    (``_force_standard_contiguous`` on rank 0 before ``set_model_state_dict``)
    keeps sender and receiver layouts consistent.

    Two model-side checks are made after the load:

    * bit-exact agreement across ranks on the live FlatParameter (for
      ``use_orig_params=False``) / each original parameter (for True), AND
    * rank 0's loaded values equal a snapshot taken just before the save.

    Output equivalence isn't sufficient -- a permuted conv weight preserves
    abs-sum and the model can drift toward similar outputs over noise -- so
    the parameter values are checked directly. Cross-rank parity alone isn't
    sufficient either: ``sync_module_states=True`` already makes the freshly
    built model identical on every rank, so a load that silently did nothing
    would still pass it. The snapshot comparison closes that gap.

    The optimizer state is intentionally *not* asserted here. The optim load
    path is layout-correct (verified by the standalone smoketest and by
    running this test in isolation), but suite-level state pollution (NCCL
    allreduce ordering across many prior tests) accumulates FP noise in the
    pre-load training step, which then survives the load and makes a tight
    cross-rank check flaky. The existing ``test_fsdp_checkpoint_roundtrip``
    already covers the optim path with a tolerance-based output comparison
    that's robust to that noise.
    """
    dm = DistributedManager()
    if dm.world_size < 2:
        pytest.skip("Need at least 2 ranks")

    device = dm.device
    mesh = init_device_mesh("cuda", (dm.world_size,), mesh_dim_names=("world",))

    # Build, move to channels_last (only conv weights are affected), wrap.
    torch.manual_seed(0)
    model = _ConvNet().to(device=device, memory_format=torch.channels_last)
    fsdp_model = FSDP(
        model,
        device_mesh=mesh["world"],
        sharding_strategy=sharding_strategy,
        use_orig_params=use_orig_params,
        sync_module_states=True,
    )
    optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)

    # Same x on every rank (we want the source-of-truth state to be identical
    # across ranks pre-save, so any post-load divergence is checkpoint-induced).
    x = torch.randn(2, 4, 8, 8, device=device).contiguous(
        memory_format=torch.channels_last
    )
    for _ in range(2):
        fsdp_model(x).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Snapshot the logical parameter values that are about to be saved.
    # Collective: gather the full model state dict on every rank; only rank 0
    # keeps a copy. ``contiguous()`` pins a canonical layout for comparison;
    # the values are what matter.
    full_options = StateDictOptions(full_state_dict=True)
    pre_save_state = get_model_state_dict(fsdp_model, options=full_options)
    saved_params: dict[str, torch.Tensor] = {}
    if dm.rank == 0:
        saved_params = {
            name: value.detach().clone().contiguous().cpu()
            for name, value in pre_save_state.items()
            if isinstance(value, torch.Tensor)
        }

    save_checkpoint(
        shared_tmp_dir,
        models=fsdp_model,
        optimizer=optimizer,
        epoch=2,
        optimizer_model=fsdp_model,
    )
    dist.barrier()

    # Build a *differently-seeded* fresh model so the pre-load values differ
    # from the saved ones (seed 0) and the load has to actually change them.
    torch.manual_seed(dm.rank + 1234)
    model2 = _ConvNet().to(device=device, memory_format=torch.channels_last)
    fsdp_model2 = FSDP(
        model2,
        device_mesh=mesh["world"],
        sharding_strategy=sharding_strategy,
        use_orig_params=use_orig_params,
        sync_module_states=True,
    )
    optimizer2 = torch.optim.Adam(fsdp_model2.parameters(), lr=1e-3)
    # Step once so optimizer state is shaped before the load.
    fsdp_model2(x).sum().backward()
    optimizer2.step()
    optimizer2.zero_grad()

    epoch = load_checkpoint(
        shared_tmp_dir,
        models=fsdp_model2,
        optimizer=optimizer2,
        optimizer_model=fsdp_model2,
    )
    assert epoch == 2

    # --- Cross-rank parity checks ------------------------------------------
    # FlatParameter (use_orig_params=False) or each original param (True).
    if use_orig_params:
        for name, p in fsdp_model2.named_parameters():
            assert _all_ranks_bit_exact(p), (
                f"Parameter '{name}' (shape={tuple(p.shape)}) differs across "
                f"ranks after channels_last+FSDP load"
            )
    else:
        flat_param = fsdp_model2._flat_param
        assert _all_ranks_bit_exact(flat_param), (
            "FlatParameter differs across ranks after channels_last+FSDP load"
        )

    # --- Loaded values match saved values (rank 0) --------------------------
    # Collective: gather the full model state dict on every rank. Parity above
    # already ties the other ranks to rank 0, so comparing on rank 0 suffices.
    loaded_state = get_model_state_dict(fsdp_model2, options=full_options)
    if dm.rank == 0:
        for name, expected in saved_params.items():
            assert name in loaded_state, f"Loaded model state missing '{name}'"
            actual = loaded_state[name].detach().contiguous().cpu()
            assert torch.equal(actual, expected), (
                f"Logical model values for '{name}' differ between save and "
                f"load after channels_last+FSDP round-trip"
            )

    # Optimizer state cross-rank check intentionally omitted -- see docstring.
305+
306+
307+
# ---------------------------------------------------------------------------
308+
# Cross-mode load: 1-proc non-distributed save → N-proc FSDP load (with CL)
309+
# ---------------------------------------------------------------------------
310+
311+
312+
@pytest.mark.timeout(30)
@pytest.mark.multigpu_static
def test_cross_mode_channels_last_model_load(shared_tmp_dir):
    """Non-distributed channels_last save on rank 0, then N-proc FSDP model load.

    Mirrors the realistic "trained on multi-rank, fine-tuned/inspected on a
    single GPU, resumed multi-rank" round-trip with channels_last. It confirms
    that the on-disk model state written by the non-distributed save path can
    be loaded by the distributed FSDP load path on every rank without
    layout-induced corruption.

    Model-side assertions:

    * the post-load FlatParameter is bit-exact identical on every rank, AND
    * rank 0's logical values match the snapshot taken before saving.

    Cross-rank parity on its own is satisfied by "everyone got the same wrong
    values" (e.g. a silent drop), hence the second check against the snapshot.

    Optimizer cross-mode load is *not* tested here. The non-distributed save
    path writes int-keyed (param-id) optim state via ``optimizer.state_dict()``,
    while the distributed FSDP load path expects FQN-keyed input -- DCP's
    ``_split_optim_state_dict`` early-returns for int keys without converting,
    and the downstream ``_rekey_sharded_optim_state_dict`` then crashes on
    ``int.unflat_param_names``. That is a separate, pre-existing limitation
    of cross-mode optim restore; same-mode optim restore is exercised by
    ``test_fsdp_checkpoint_channels_last_roundtrip`` and is what the
    channels_last fix is concerned with.
    """
    dm = DistributedManager()
    if dm.world_size < 2:
        pytest.skip("Need at least 2 ranks")

    device = dm.device
    on_rank0 = dm.rank == 0

    # ----- Phase A: single-process save, performed by rank 0 alone -----
    saved_params: dict[str, torch.Tensor] = {}
    if on_rank0:
        torch.manual_seed(0)
        src_model = _ConvNet().to(device=device, memory_format=torch.channels_last)
        src_optimizer = torch.optim.Adam(src_model.parameters(), lr=1e-3)

        batch = torch.randn(2, 4, 8, 8, device=device)
        batch = batch.contiguous(memory_format=torch.channels_last)
        # Take two optimizer steps so the saved weights are no longer at init.
        for _ in range(2):
            loss = src_model(batch).sum()
            loss.backward()
            src_optimizer.step()
            src_optimizer.zero_grad()

        # Snapshot in a canonical (standard-contiguous, CPU) layout; only the
        # values matter for the later comparison.
        saved_params = {
            name: param.detach().clone().contiguous().cpu()
            for name, param in src_model.named_parameters()
        }

        # The optimizer state is saved as well, as real-world usage would,
        # although the load side never consumes it (see docstring).
        save_checkpoint(
            shared_tmp_dir,
            models=src_model,
            optimizer=src_optimizer,
            epoch=2,
        )
    dist.barrier()

    # ----- Phase B: N-process FSDP load, model state only -----
    mesh = init_device_mesh("cuda", (dm.world_size,), mesh_dim_names=("world",))

    # Seed each rank differently so sync_module_states alone can't mask anything.
    torch.manual_seed(dm.rank + 4242)
    dst_model = _ConvNet().to(device=device, memory_format=torch.channels_last)
    fsdp_load = FSDP(
        dst_model,
        device_mesh=mesh["world"],
        sharding_strategy=ShardingStrategy.NO_SHARD,
        use_orig_params=False,
        sync_module_states=True,
    )

    # No optimizer is passed: cross-mode optim load is a separate, pre-existing
    # PyTorch DCP limitation (see docstring). Only the model path is under test.
    epoch = load_checkpoint(shared_tmp_dir, models=fsdp_load)
    assert epoch == 2

    # ----- Phase C.1: every rank holds the same FlatParameter -----
    assert _all_ranks_bit_exact(fsdp_load._flat_param), (
        "FlatParameter differs across ranks after cross-mode model load"
    )

    # ----- Phase C.2: loaded values equal the saved snapshot (rank 0) -----
    # Collective: every rank gathers the full model state dict.
    full_options = StateDictOptions(full_state_dict=True)
    full_loaded_model = get_model_state_dict(fsdp_load, options=full_options)
    if not on_rank0:
        return

    for name, expected in saved_params.items():
        assert name in full_loaded_model, (
            f"Loaded model state missing '{name}'"
        )
        actual = full_loaded_model[name].detach().contiguous().cpu()
        assert torch.equal(actual, expected), (
            f"Logical model values for '{name}' differ between save and "
            f"load (cross-mode)"
        )
421+
422+
166423
# ---------------------------------------------------------------------------
167424
# load_model_weights — plain FSDP
168425
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)