Skip to content

Commit 3b1e4ce

Browse files
Merge branch 'main' into flash_attn_pad_bw_seqs
2 parents 0638d58 + 9af70a8 commit 3b1e4ce

15 files changed

Lines changed: 501 additions & 219 deletions

tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name):
228228
"""
229229
recipe = get_recipe_from_string(recipe_name)
230230

231-
if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"):
232-
pytest.xfail(
233-
f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor "
234-
"subclasses fail when parameters are initialized on CUDA. "
235-
"Use device='meta' + reset_parameters() after sharding."
236-
)
237-
238231
world_size, device = _get_dist_info()
239232

240233
model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False)
@@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name):
604597
- Saved tensor shapes match expected (unsharded) shapes
605598
"""
606599
recipe = get_recipe_from_string(recipe_name)
607-
if recipe_name == "MXFP8BlockScaling":
608-
pytest.xfail(
609-
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
610-
"MXFP8 quantized tensors, causing illegal memory access. "
611-
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
612-
)
613600

614601
from safetensors.torch import load_file, save_file
615602
from torch.distributed.checkpoint.state_dict import (
@@ -692,40 +679,14 @@ def test_dcp_output_parity(recipe_name, async_save):
692679
"""
693680
recipe = get_recipe_from_string(recipe_name)
694681

695-
if recipe_name == "MXFP8BlockScaling":
696-
pytest.xfail(
697-
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
698-
"MXFP8 quantized tensors, causing illegal memory access: "
699-
"/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function "
700-
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered. "
701-
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
702-
)
703-
704-
if recipe_name == "NVFP4BlockScaling":
705-
pytest.xfail(
706-
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
707-
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
708-
)
709-
710-
if (
711-
recipe_name == "Float8BlockScaling"
712-
and not async_save
713-
and torch.cuda.get_device_capability()[0] == 12
714-
):
682+
if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
715683
pytest.xfail(
716684
"Float8BlockScaling is failing on SM120 with RuntimeError: "
717685
"transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 "
718686
"in function quantize_transpose_vector_blockwise: Assertion failed: pow2_scale. On "
719687
"Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which "
720688
"requires using power of two scaling factors."
721689
)
722-
if recipe_name == "Float8BlockScaling" and async_save:
723-
pytest.xfail(
724-
"Float8BlockScaling: async DCP save/load round-trip produces different model "
725-
"outputs — quantization metadata (scales) is not correctly persisted through "
726-
"async distributed checkpointing. On SM120, additionally fails with pow2_scale "
727-
"assertion in quantize_transpose_vector_blockwise."
728-
)
729690

730691
import torch.distributed.checkpoint as dcp
731692

tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -379,20 +379,11 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
379379
"sending only 1 tensor (scale is per-tensor metadata). Fix: concatenate MXFP8 "
380380
"data and scale_inv into a single buffer in pre_all_gather, split in post."
381381
)
382-
383382
if recipe_name == "Float8BlockScaling" and fp8_init:
384383
pytest.xfail(
385384
"Float8BlockScaling + fp8_init: scale inverse padding is not handled "
386385
"correctly during FSDP2 all-gather slice ops."
387386
)
388-
if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer":
389-
pytest.xfail(
390-
"NVFP4BlockScaling + fp8_init + TransformerLayer: "
391-
"_check_fp8_fsdp2_allgather numerical error compounds across multiple "
392-
"linear layers in the transformer block (up to ~1e-2 max abs diff). "
393-
"LayerNormLinear passes with relaxed tolerances. "
394-
"NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py."
395-
)
396387
torch.manual_seed(42)
397388
torch.cuda.manual_seed(42)
398389

tests/pytorch/test_quantized_tensor.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,56 @@ def test_identity_op(
616616
torch.testing.assert_close(y_test, y_ref, **tols)
617617
torch.testing.assert_close(dx_test, dx_ref, **tols)
618618

619+
@pytest.mark.parametrize("quantization", _quantization_list)
620+
def test_cpu_dequantize(
621+
self,
622+
*,
623+
quantization: str,
624+
shape: Iterable[int] = (128, 128),
625+
dtype: torch.dtype = torch.bfloat16,
626+
) -> None:
627+
"""Dequantize on a CPU-resident QuantizedTensor."""
628+
629+
# Construct a quantized tensor on CUDA.
630+
_, x_cuda = make_reference_and_test_tensors(
631+
shape=shape,
632+
quantization=quantization,
633+
test_dtype=dtype,
634+
requires_grad=False,
635+
)
636+
assert isinstance(x_cuda, QuantizedTensor)
637+
assert x_cuda.device.type == "cuda"
638+
639+
# Reference: dequantize on CUDA, then move the dense result to CPU.
640+
ref_cpu = x_cuda.dequantize().to(device="cpu")
641+
642+
# Move the QuantizedTensor itself to CPU and dequantize there.
643+
# ``.cpu()`` routes through ``aten._to_copy.default`` so all inner
644+
# buffers (data, scales, amax) are moved to CPU.
645+
x_cpu = x_cuda.cpu()
646+
assert isinstance(x_cpu, QuantizedTensor)
647+
assert x_cpu.device.type == "cpu"
648+
for attr in (
649+
"_data",
650+
"_rowwise_data",
651+
"_columnwise_data",
652+
"_rowwise_scale_inv",
653+
"_columnwise_scale_inv",
654+
"_amax_rowwise",
655+
"_amax_columnwise",
656+
):
657+
buf = getattr(x_cpu, attr, None)
658+
if buf is not None:
659+
assert buf.device.type == "cpu", f"{attr} did not move to CPU"
660+
661+
# Dequantize the CPU tensor. Implementation may bounce through CUDA
662+
# internally, but must return a CPU tensor.
663+
y_cpu = x_cpu.dequantize()
664+
assert y_cpu.device.type == "cpu"
665+
assert y_cpu.dtype == ref_cpu.dtype
666+
assert y_cpu.shape == ref_cpu.shape
667+
torch.testing.assert_close(y_cpu, ref_cpu, rtol=0, atol=0)
668+
619669
@pytest.mark.parametrize("quantization", _quantization_list)
620670
@pytest.mark.parametrize("dim", [0, 1])
621671
def test_chunk(

transformer_engine/common/util/pybind_helper.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@
2323
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
2424
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
2525
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
26-
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
26+
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
27+
.def("__reduce_ex__", \
28+
[](transformer_engine::DType self, pybind11::object /*protocol*/) { \
29+
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
30+
pybind11::make_tuple(static_cast<int>(self))); \
31+
}) \
32+
.def("__reduce__", [](transformer_engine::DType self) { \
33+
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
34+
pybind11::make_tuple(static_cast<int>(self))); \
35+
}); \
2736
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
2837
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
2938
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \

transformer_engine/pytorch/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,66 @@
8989
from transformer_engine.pytorch.tensor import MXFP8Tensor
9090
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
9191
from transformer_engine.pytorch.tensor import NVFP4Tensor
92+
from transformer_engine.pytorch.tensor.float8_tensor import (
93+
_make_float8_tensor_in_reduce_ex,
94+
)
95+
from transformer_engine.pytorch.tensor.mxfp8_tensor import (
96+
_make_mxfp8_tensor_in_reduce_ex,
97+
)
98+
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
99+
_make_nvfp4_tensor_in_reduce_ex,
100+
)
101+
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
102+
_make_float8_blockwise_tensor_in_reduce_ex,
103+
)
92104

93105
try:
94106
torch._dynamo.config.error_on_nested_jit_trace = False
95107
except AttributeError:
96108
pass # error_on_nested_jit_trace was added in PyTorch 2.2.0
109+
110+
# To allow for safe unpickling of QuantizedTensors when using DCP
111+
# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its
112+
# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see
113+
# ``transformer_engine/common/util/pybind_helper.h``) so its pickle
114+
# stream encodes as ``(tex.DType, (int,))`` and only the class itself
115+
# needs to be allow-listed below.
116+
try:
117+
from torch.serialization import add_safe_globals
118+
import transformer_engine_torch as tex
119+
120+
add_safe_globals(
121+
[
122+
# Storage mixins (used during pickling of internal-only tensors)
123+
QuantizedTensorStorage,
124+
Float8TensorStorage,
125+
MXFP8TensorStorage,
126+
NVFP4TensorStorage,
127+
Float8BlockwiseQTensorStorage,
128+
# Quantizer types embedded in metadata
129+
Quantizer,
130+
Float8Quantizer,
131+
Float8CurrentScalingQuantizer,
132+
MXFP8Quantizer,
133+
NVFP4Quantizer,
134+
Float8BlockQuantizer,
135+
# pybind11 enum used as Quantizer.dtype
136+
tex.DType,
137+
# __reduce_ex__ reconstructors (module-level functions).
138+
_make_float8_tensor_in_reduce_ex,
139+
_make_mxfp8_tensor_in_reduce_ex,
140+
_make_nvfp4_tensor_in_reduce_ex,
141+
_make_float8_blockwise_tensor_in_reduce_ex,
142+
]
143+
)
144+
except (ImportError, AttributeError):
145+
import warnings as _warnings
146+
147+
_warnings.warn(
148+
"transformer_engine: torch.serialization.add_safe_globals is "
149+
"unavailable on this PyTorch version (added in 2.4). DCP "
150+
"checkpointing of QuantizedTensor weights with FSDP2 will not "
151+
"work; upgrade to PyTorch >= 2.4 to enable it.",
152+
RuntimeWarning,
153+
stacklevel=2,
154+
)

transformer_engine/pytorch/module/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
4545
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
4646
from ..tensor.mxfp8_tensor import MXFP8Quantizer
47+
from ..tensor.nvfp4_tensor import NVFP4Quantizer
4748
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
4849
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
4950
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
@@ -1641,7 +1642,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
16411642
raise RuntimeError("Weight quantizer has not been initialized")
16421643
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
16431644
quantizer.internal = False
1644-
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
1645+
if is_dtensor and isinstance(
1646+
quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer)
1647+
):
16451648
device_mesh = dtensor_param.device_mesh
16461649
amax_reduction_group = (
16471650
device_mesh.get_group(mesh_dim="shard")

transformer_engine/pytorch/quantized_tensor.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,26 @@ def half(self) -> torch.Tensor:
552552
# pylint: disable=missing-function-docstring
553553
return self.dequantize(dtype=torch.float16)
554554

555-
def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
555+
def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor:
556+
"""Move tensor to CPU while preserving the QuantizedTensor type.
557+
558+
Routes through ``aten._to_copy.default`` so the subclass-preserving
559+
handler in ``__torch_dispatch__`` runs (rather than dequantizing).
560+
561+
"""
556562
# pylint: disable=missing-function-docstring
557-
return self.dequantize().cpu(memory_format=memory_format)
563+
return self.to(device=torch.device("cpu"), memory_format=memory_format)
564+
565+
def untyped_storage(self) -> torch.UntypedStorage:
566+
"""Return an empty UntypedStorage on the tensor's device.
567+
568+
``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
569+
backing storage of its own; the actual bytes live in the inner
570+
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
571+
an implementation detail of the quantization scheme. This method is
572+
defined to avoid DCP staging errors with FSDP2.
573+
"""
574+
return torch.UntypedStorage(0, device=self.device)
558575

559576
def expand_as(self, other: torch.Tensor) -> torch.Tensor:
560577
# pylint: disable=missing-function-docstring
@@ -608,6 +625,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
608625
dst.copy_(src)
609626
return None
610627

628+
# _to_copy op (used by .to(device=...), .cpu(), DCP staging).
629+
# Preserve the QuantizedTensor subclass and move all internal
630+
# buffers (data, scales, etc.) to the requested device.
631+
if func == torch.ops.aten._to_copy.default:
632+
tensor = args[0]
633+
kw = dict(kwargs) if kwargs else {}
634+
dtype = kw.get("dtype", None)
635+
if dtype is None or dtype == tensor.dtype:
636+
target_device = kw.get("device", tensor.device) or tensor.device
637+
target_device = torch.device(target_device)
638+
pin_memory = bool(kw.get("pin_memory", False))
639+
non_blocking = bool(kw.get("non_blocking", False))
640+
new_metadata = {"device": target_device}
641+
# Update tensor storage metadata
642+
for key, value in tensor.get_metadata().items():
643+
if isinstance(value, torch.Tensor):
644+
value = value.to(device=target_device, non_blocking=non_blocking)
645+
if pin_memory and target_device.type == "cpu":
646+
value = value.pin_memory()
647+
new_metadata[key] = value
648+
# Update torch Tensor metadata
649+
new_metadata.update(
650+
{
651+
"dtype": tensor.dtype,
652+
"shape": tensor.shape,
653+
"requires_grad": tensor.requires_grad,
654+
}
655+
)
656+
return type(tensor)(**new_metadata)
657+
611658
# View op
612659
if func == torch.ops.aten.view.default:
613660
raise NotImplementedError("{cls.__name__} class does not support tensor views")
@@ -748,14 +795,19 @@ def make_like(
748795
"""Create new quantized tensor
749796
750797
By default, new tensor has the same attributes and underlying
751-
data. This function is intended to create view of tensors.
752-
798+
data. This function is intended to create a view of ``tensor``.
753799
"""
754800
shape = shape if shape is not None else tensor.shape
755801
dtype = dtype if dtype is not None else tensor.dtype
756802
kwargs = tensor.get_metadata()
757803
kwargs["fake_dtype"] = dtype
758-
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
804+
return cls(
805+
shape=shape,
806+
dtype=dtype,
807+
requires_grad=requires_grad,
808+
device=tensor.device,
809+
**kwargs,
810+
)
759811

760812
def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
761813
"""Create `QuantizedTensor` with given nominal dtype

transformer_engine/pytorch/tensor/_quantization_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def forward(
6161
kwargs = tensor.get_metadata()
6262
for key, val in init_kwargs.items():
6363
kwargs[key] = val
64+
kwargs["device"] = tensor.device
6465
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
6566

6667
@staticmethod

0 commit comments

Comments
 (0)