Skip to content

Commit af7362a

Browse files
committed
Add model weight modification guard to ensure DCP checkpoint correctness.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 5dd64ea commit af7362a

4 files changed

Lines changed: 20 additions & 4 deletions

File tree

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,22 @@ def _train(args):
548548
loss.backward()
549549
optimizer.step()
550550

551+
# Verify model weights have diverged from the original
552+
# model state after extra training steps.
553+
s_post_train = model.state_dict()
554+
for key in s1.keys() & s_post_train.keys():
555+
if key.endswith("_extra_state"):
556+
continue
557+
v1 = s1[key]
558+
if isinstance(v1, DTensor):
559+
v1 = v1.to_local()
560+
v_pt = s_post_train[key]
561+
if isinstance(v_pt, DTensor):
562+
v_pt = v_pt.to_local()
563+
assert not torch.allclose(v1, v_pt), (
564+
f"[{key}] Model weights should have changed after extra training steps"
565+
)
566+
551567
# Load the checkpoint.
552568
state_dict = {"app": AppState(model=model, optimizer=optimizer)}
553569
torch.distributed.checkpoint.load(state_dict=state_dict, checkpoint_id=str(CKPT_DIR))

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
9898
f"Insufficient devices ({NUM_PROCS}) to test sharding configuration: {sharding_dims}"
9999
)
100100

101-
if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
102-
pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.")
103-
104101
_run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
105102

106103

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def untyped_storage(self) -> torch.UntypedStorage:
403403
data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
404404
if data is not None:
405405
return data.untyped_storage()
406-
return torch.UntypedStorage(0, device=self.device)
406+
return self._default_storage
407407

408408
@classmethod
409409
def __torch_dispatch__(cls, func, types, args, kwargs=None):

transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
3535
_rowwise_scale_inv: Optional[torch.Tensor]
3636
_columnwise_scale_inv: Optional[torch.Tensor]
3737
_is_2D_scaled: bool
38+
# 1-byte placeholder storage, returned by untyped_storage() when neither rowwise nor columnwise data is set.
39+
_default_storage: torch.UntypedStorage
3840

3941
def __new__(
4042
cls,
@@ -61,6 +63,7 @@ def __new__(
6163
instance._rowwise_scale_inv = rowwise_scale_inv
6264
instance._columnwise_scale_inv = columnwise_scale_inv
6365
instance._is_2D_scaled = is_2D_scaled
66+
instance._default_storage = torch.UntypedStorage(1)
6467

6568
return instance
6669

0 commit comments

Comments
 (0)