Skip to content

Commit 15df86f

Browse files
committed
Update storage-related xFails.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 078b35a commit 15df86f

2 files changed

Lines changed: 13 additions & 21 deletions

File tree

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 7 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -133,10 +133,9 @@ def _run_allgather_test(sharding_dims, recipe):
133133
)
134134
def test_fp8_fsdp2_allgather(sharding_dims, fp_recipe):
135135
"""Verify FSDP2 FP8 all-gather matches a manual dequantize-then-gather reference."""
136-
if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling"):
136+
if fp_recipe == "NVFP4BlockScaling":
137137
pytest.xfail(
138-
f"{fp_recipe}: block-scaled quantization formats are not supported by the "
139-
"FP8 FSDP2 all-gather correctness test."
138+
f"{fp_recipe}: NVFP4 FSDP2 all-gather hooks need to be implemented."
140139
)
141140

142141
parallel_size = math.prod(x for x in sharding_dims if x != 0)
@@ -181,20 +180,7 @@ def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
181180

182181
@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
183182
def test_fsdp2_fused_adam_fp8_master_weights_no_meta(fp_recipe):
184-
"""FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (CUDA init, no meta device).
185-
186-
Block-scaling QuantizedTensors (MXFP8, Float8Blockwise, NVFP4) are wrapper
187-
subclasses with data_ptr() == 0. Without meta-device init, FSDP2's
188-
reset_sharded_param() crashes with 'invalid python storage'.
189-
Per-tensor FP8 (DelayedScaling, Float8CurrentScaling) works because
190-
Float8Tensor's storage is accessible.
191-
"""
192-
if fp_recipe in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"):
193-
pytest.xfail(
194-
f"{fp_recipe}: FSDP2 without meta-device init crashes on block-scaling "
195-
"QuantizedTensor wrapper subclasses (data_ptr() == 0). "
196-
"Use device='meta' + reset_parameters() after sharding."
197-
)
183+
"""FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (CUDA init, no meta device)."""
198184
_run_fused_adam_test("fused_adam_fp8_master_weights_no_meta", fp_recipe)
199185

200186

@@ -232,8 +218,8 @@ def test_fsdp2_dcp_output_parity(fp_recipe):
232218

233219
if fp_recipe == "NVFP4BlockScaling":
234220
pytest.xfail(
235-
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
236-
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
221+
"NVFP4BlockScaling: Failing parity tests with DCP. Snippet: \n"
222+
"Fresh model loaded from DCP checkpoint produces different output."
237223
)
238224

239225
if fp_recipe == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
@@ -261,8 +247,8 @@ def test_fsdp2_dcp_output_parity_async(fp_recipe):
261247

262248
if fp_recipe == "NVFP4BlockScaling":
263249
pytest.xfail(
264-
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
265-
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
250+
"NVFP4BlockScaling: Failing parity tests with DCP. Snippet: \n"
251+
"Fresh model loaded from DCP checkpoint produces different output."
266252
)
267253

268254
if fp_recipe == "Float8BlockScaling":

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from typing import Any, Optional, Tuple, Union
1111

1212
import torch
13+
from torch.distributed.tensor import DTensor
1314

1415
import transformer_engine_torch as tex
1516
from transformer_engine_torch import DType as TE_DType
@@ -726,6 +727,11 @@ def fsdp_post_all_gather(
726727
# columnwise_data is (K, full_M), logical shape is (full_M, K)
727728
data_shape = (columnwise_data.shape[1], columnwise_data.shape[0])
728729

730+
if isinstance(out, DTensor):
731+
# out.to_local() is not supported with Torch Dispatch,
732+
# for quantized tensors with _transpose usage.
733+
out = out._local_tensor
734+
729735
if out is not None:
730736
# Update existing tensor in-place (subsequent iterations)
731737
out._rowwise_data = rowwise_data

0 commit comments

Comments
 (0)