Skip to content

Commit 8a9201e

Browse files
committed
Optimizer state load improvement
1 parent ec3b375 commit 8a9201e

4 files changed

Lines changed: 155 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
173173
`smooth_laplacian` and `compute_quality_metrics` have been replaced
174174
with the dtype-aware `.clamp(min=safe_eps(dtype))` to avoid silently
175175
zeroing fp16 weights.
176-
- Fixed a silent bug in loading of optimizer state from checkpoint for
176+
- Fixed a silent bug in loading state from checkpoint for
177177
FSDP-backed models with `use_orig_params=False` and channels last
178178
memory format.
179179
- Fixed issues with physicsnemo.nn.functional's `radius_search` that

examples/weather/stormcast/test_training.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,13 @@ def test_checkpoint_integrity(
292292
(params0, opt_params0) = get_state_dict(net0, opt0, options=options)
293293
(params1, opt_params1) = get_state_dict(net1, opt1, options=options)
294294

295+
assert set(params0.keys()) == set(params1.keys()), (
296+
"State dicts before and after checkpointing have different keys"
297+
)
298+
assert set(opt_params0.keys()) == set(opt_params1.keys()), (
299+
"Optimizer state dicts before and after checkpointing have different keys"
300+
)
301+
295302
for key, param0 in params0.items():
296303
param1 = params1[key]
297304
assert (param0 == param1).all().cpu().item(), (
@@ -305,6 +312,38 @@ def test_checkpoint_integrity(
305312
f"Optimizer parameter {key} before and after checkpointing is not equal"
306313
)
307314

315+
for _ in range(5):
316+
t1.train_step()
317+
t1.save_checkpoint()
318+
319+
torch.distributed.barrier()
320+
321+
# flip sharding setting to test that sharded checkpoints load ok in non-sharded mode and vice versa
322+
cfg_diffusion.training.force_sharding = not cfg_diffusion.training.force_sharding
323+
t2 = trainer.Trainer(cfg_diffusion.copy())
324+
net2 = t2.net
325+
opt2 = t2.optimizer
326+
327+
options = StateDictOptions(full_state_dict=True)
328+
(params1, opt_params1) = get_state_dict(net1, opt1, options=options)
329+
(params2, opt_params2) = get_state_dict(net2, opt2, options=options)
330+
331+
assert set(params1.keys()) == set(params2.keys())
332+
assert set(opt_params1.keys()) == set(opt_params2.keys())
333+
334+
for key, param1 in params1.items():
335+
param2 = params2[key]
336+
assert (param1 == param2).all().cpu().item(), (
337+
f"Model parameter {key} before (force_sharding={force_sharding}) and after force_sharding={not force_sharding} checkpointing is not equal"
338+
)
339+
340+
for key, opt_param1 in opt_params1["state"].items():
341+
opt_param2 = opt_params2["state"][key]
342+
for opt_var in opt_param1:
343+
assert (opt_param1[opt_var] == opt_param2[opt_var]).all().cpu().item(), (
344+
f"Optimizer parameter {key} before (force_sharding={force_sharding}) and after force_sharding={not force_sharding} checkpointing is not equal"
345+
)
346+
308347
if dist.world_size != 4:
309348
return # remaining tests are for the 4-GPU setup
310349

physicsnemo/utils/checkpoint.py

Lines changed: 114 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,66 @@ def _fsdp_uses_flat_param_optim(model: torch.nn.Module | None) -> bool:
265265
return not getattr(model, "_use_orig_params", True)
266266

267267

268+
def _get_cl_param_fqns(opt_model: torch.nn.Module | None) -> set[str]:
    """Collect FQNs of FSDP-flattened params recorded with a channels-last layout.

    Walks every FSDP submodule of *opt_model* and inspects the per-parameter
    metadata stored on its ``_flat_param`` (``_fqns``, ``_shapes``,
    ``_strides``, ``_contiguities``). An original parameter is reported when:

    * its ``_contiguities`` entry is falsy -- the same bit
      ``_get_unflat_views`` consults to choose ``as_strided`` over ``view``,
      i.e. the ``FlatParameter`` slot holds bytes in a non-standard order;
    * its recorded shape is 4-D or 5-D, with a stride tuple of equal length;
    * its channel stride (index 1) is 1, which is the case for both
      ``channels_last`` (NHWC) and ``channels_last_3d`` (NDHWC) storage.

    Reported names are ``{path_to_FSDP_module}.{flat_param_fqn}`` with every
    ``_fsdp_wrapped_module`` path segment dropped (mirroring DCP's
    ``_get_fqns`` so the names line up with ``optim_sd["state"]`` keys) and a
    leading ``_orig_mod.`` (``torch.compile``) prefix removed, matching the
    normalization ``save_checkpoint`` applies to optimizer ``param_names``.

    Parameters
    ----------
    opt_model : torch.nn.Module | None
        Model the optimizer was built against.

    Returns
    -------
    set[str]
        Normalized FQNs of the matching parameters; empty when *opt_model* is
        not FSDP-wrapped with ``use_orig_params=False`` (the only
        configuration with the flatten/unflatten asymmetry).
    """
    if not _fsdp_uses_flat_param_optim(opt_model):
        return set()

    wrapper_segment = "_fsdp_wrapped_module"
    compile_prefix = "_orig_mod."
    found: set[str] = set()

    for mod_path, submodule in opt_model.named_modules():
        if not isinstance(submodule, FSDP):
            continue
        flat = getattr(submodule, "_flat_param", None)
        if flat is None:
            continue

        # Rebuild the dotted path without FSDP's wrapper attribute (and
        # without the empty segment produced by the root module's "" name).
        kept = [part for part in mod_path.split(".") if part not in ("", wrapper_segment)]
        base = ".".join(kept) + "." if kept else ""

        records = zip(
            flat._fqns,
            flat._shapes,
            flat._strides,
            flat._contiguities,
        )
        # Keep only non-contiguous 4-D / 5-D entries whose channel dim is the
        # fastest-varying one in storage.
        found.update(
            (base + name).removeprefix(compile_prefix)
            for name, dims, steps, is_contig in records
            if not is_contig
            and len(dims) in (4, 5)
            and len(steps) == len(dims)
            and steps[1] == 1
        )

    return found
326+
327+
268328
def _remap_channels_last_optim_sd(
269329
opt_model: torch.nn.Module | None,
270330
optim_sd: dict[str, Any],
@@ -282,24 +342,27 @@ def _remap_channels_last_optim_sd(
282342
For a 4-D Conv2d weight in ``channels_last`` format the two orders
283343
differ, so the round-trip silently corrupts the optimizer state.
284344
285-
Detect channels_last entries directly on *optim_sd* (the saved tensor
286-
preserves its memory format through ``torch.save`` / ``torch.load``)
287-
and pre-permute them so the loader's ``torch.flatten`` produces the
288-
same byte sequence the ``FlatParameter`` was originally filled with.
289-
Other entries (and ranks that received an empty ``optim_sd`` for the
290-
broadcast-from-rank-0 path) pass through unchanged.
345+
The remap is gated on the **destination** ``FlatParameter`` slot's
346+
expected byte order (via ``flat_param._contiguities``), not on the
347+
saved tensor's layout. That's the only signal that always matches what
348+
the load-side ``_flatten_tensor_optim_state`` will do, so it works for
349+
every save/load layout combination -- in particular for
350+
``FSDP+ShardTensor`` configurations where ``distribute_module`` calls
351+
``.contiguous()`` and silently strips channels_last before FSDP wraps,
352+
making saved tensors standard-contig even though the conceptual model
353+
has CL conv weights.
354+
355+
Inputs are also normalized to standard contiguity before the layout
356+
decision: a CL tensor that *isn't* getting permuted (because the
357+
destination is non-CL) would otherwise survive into DCP's per-tensor
358+
``dist.broadcast`` and hit the same layout-blind broadcast bug
359+
``_force_standard_contiguous`` fixes for model state.
291360
292361
Only fires when *opt_model* is FSDP-wrapped with
293362
``use_orig_params=False`` -- with ``use_orig_params=True`` the
294363
asymmetry doesn't exist and the remap would *cause* the corruption it
295364
is meant to prevent.
296365
297-
Note: we cannot inspect the live model to find channels_last params --
298-
with ``use_orig_params=False`` the original parameters are hidden
299-
behind plain tensor attributes and ``named_parameters()`` only sees
300-
the 1-D ``FlatParameter``. So detection is on the saved tensors
301-
instead.
302-
303366
See ``torch/distributed/fsdp/_optim_utils.py::_flatten_tensor_optim_state``
304367
and ``_flat_param.py::flatten_tensors``.
305368
"""
@@ -308,12 +371,21 @@ def _remap_channels_last_optim_sd(
308371
if not _fsdp_uses_flat_param_optim(opt_model):
309372
return optim_sd
310373

311-
def _maybe_remap(t: torch.Tensor) -> torch.Tensor:
312-
if isinstance(t, DTensor):
374+
cl_fqns = _get_cl_param_fqns(opt_model)
375+
376+
def _normalize(t: torch.Tensor, is_cl_dest: bool) -> torch.Tensor:
377+
if isinstance(t, DTensor) or t.dim() == 0:
313378
return t
314-
if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
379+
# Force standard contiguity first so any saved-CL bytes are
380+
# rewritten in NCHW storage order before the layout decision; this
381+
# makes the subsequent broadcast inside DCP layout-safe whether or
382+
# not we permute.
383+
t = t.contiguous()
384+
if not is_cl_dest:
385+
return t
386+
if t.dim() == 4:
315387
return t.permute(0, 2, 3, 1).contiguous().view(*t.shape)
316-
if t.dim() == 5 and t.is_contiguous(memory_format=torch.channels_last_3d):
388+
if t.dim() == 5:
317389
return t.permute(0, 2, 3, 4, 1).contiguous().view(*t.shape)
318390
return t
319391

@@ -322,9 +394,10 @@ def _maybe_remap(t: torch.Tensor) -> torch.Tensor:
322394
if not isinstance(pstate, dict):
323395
new_state[pname] = pstate
324396
continue
397+
is_cl_dest = pname.removeprefix("_orig_mod.") in cl_fqns
325398
new_ps: dict[str, Any] = {}
326399
for k, v in pstate.items():
327-
new_ps[k] = _maybe_remap(v) if isinstance(v, torch.Tensor) else v
400+
new_ps[k] = _normalize(v, is_cl_dest) if isinstance(v, torch.Tensor) else v
328401
new_state[pname] = new_ps
329402

330403
return {**optim_sd, "state": new_state}
@@ -1155,16 +1228,32 @@ def _load_checkpoint_distributed(
11551228
path, index=epoch, model_type="pt", distributed=True
11561229
)
11571230

1231+
# Broadcast file existence so all ranks agree on whether to enter the
1232+
# (collective) optimizer load. Without this, a rundir that has model
1233+
# weights but no training checkpoint -- e.g. fine-tuning from a
1234+
# weights-only export -- would have rank 0 enter ``set_optimizer_state_dict``
1235+
# with an empty dict and trip the "missing 'state'" error inside DCP.
1236+
ckpt_exists = fs.exists(checkpoint_filename) if is_rank0 else None
1237+
ckpt_flags: list[Any] = [ckpt_exists]
1238+
torch.distributed.broadcast_object_list(ckpt_flags, src=0)
1239+
ckpt_exists = ckpt_flags[0]
1240+
1241+
if not ckpt_exists:
1242+
checkpoint_logging.warning(
1243+
f"No training checkpoint at {checkpoint_filename}; "
1244+
"skipping optimizer/scheduler/scaler load"
1245+
)
1246+
return 0
1247+
11581248
checkpoint_dict: dict[str, Any] = {}
11591249
if is_rank0:
1160-
if fs.exists(checkpoint_filename):
1161-
file_to_load = _cache_if_needed(checkpoint_filename)
1162-
checkpoint_dict = torch.load(
1163-
file_to_load, map_location=device, weights_only=False
1164-
)
1165-
checkpoint_logging.success(
1166-
f"Loaded checkpoint file {checkpoint_filename} to device {device}"
1167-
)
1250+
file_to_load = _cache_if_needed(checkpoint_filename)
1251+
checkpoint_dict = torch.load(
1252+
file_to_load, map_location=device, weights_only=False
1253+
)
1254+
checkpoint_logging.success(
1255+
f"Loaded checkpoint file {checkpoint_filename} to device {device}"
1256+
)
11681257

11691258
# Optimizer state via DCP (collective)
11701259
if optimizer:

test/utils/test_checkpoint_distributed.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,7 @@ def test_cross_mode_channels_last_model_load(shared_tmp_dir):
410410
)
411411
if dm.rank == 0:
412412
for name, expected in saved_params.items():
413-
assert name in full_loaded_model, (
414-
f"Loaded model state missing '{name}'"
415-
)
413+
assert name in full_loaded_model, f"Loaded model state missing '{name}'"
416414
actual = full_loaded_model[name].detach().contiguous().cpu()
417415
assert torch.equal(actual, expected), (
418416
f"Logical model values for '{name}' differ between save and "

0 commit comments

Comments
 (0)