Skip to content

Commit 9435382

Browse files
committed
Fork the AG test to a separate module / test so NVFP4 and Float8Block are still model parity tested.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent af7362a commit 9435382

4 files changed

Lines changed: 299 additions & 54 deletions

File tree

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#!/usr/bin/python3
2+
3+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# See LICENSE for license information.
6+
7+
"""
8+
Standalone test for FP8 FSDP2 all-gather correctness.
9+
10+
Verifies that FSDP2's internal all-gather of FP8 parameters produces the same
11+
result as a manual all-gather of dequantized FP32 values.
12+
"""
13+
14+
import argparse
15+
import os
16+
import sys
17+
from contextlib import nullcontext
18+
19+
import transformer_engine.pytorch as te
20+
import transformer_engine.common.recipe
21+
from transformer_engine.pytorch import fp8_model_init
22+
import torch
23+
import torch.distributed as dist
24+
import torch.nn.functional as F
25+
from torch import optim
26+
from torch.distributed.tensor import DTensor
27+
from torch.distributed._composable.fsdp import fully_shard
28+
from torch.distributed.device_mesh import init_device_mesh
29+
from torch import nn
30+
31+
LOCAL_RANK = None
32+
33+
# Fixed model dimensions — this test focuses on allgather correctness, not model flexibility.
34+
_NUM_HEADS = 8
35+
_HEAD_DIM = 64
36+
_HIDDEN_SIZE = _NUM_HEADS * _HEAD_DIM # 512
37+
_FFN_SIZE = _HIDDEN_SIZE * 4 # 2048
38+
_NUM_LAYERS = 2
39+
_BATCH_SIZE = 4
40+
_SEQ_LEN = 32
41+
42+
43+
def dist_print(msg):
    """
    Print ``msg`` only on the process whose ``LOCAL_RANK`` is 0.

    Every other process returns without printing, so a message is not
    repeated by each local rank.
    """
    if LOCAL_RANK != 0:
        return
    print(msg)
46+
47+
48+
def _parse_args():
    """
    Parse the command-line options of the all-gather test script.

    Returns:
        The ``argparse.Namespace`` holding ``recipe``, ``sharding_dims`` and
        ``seed``, plus a derived ``tp_size`` attribute: the third sharding
        dimension when three are given, otherwise 1.

    Exits via ``parser.error`` (usage message, exit status 2) when more than
    three sharding dimensions are supplied.
    """
    parser = argparse.ArgumentParser(
        description="Test FP8 FSDP2 all-gather correctness with TransformerLayer."
    )
    parser.add_argument(
        "--recipe",
        type=str,
        default="DelayedScaling",
        choices=[
            "DelayedScaling",
            "Float8CurrentScaling",
            "Float8BlockScaling",
            "MXFP8BlockScaling",
            "NVFP4BlockScaling",
        ],
    )
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        required=True,
        help=(
            'Sharding mesh dimensions: ("dp_shard",), ("dp_replicate", "dp_shard"), '
            'or ("dp_replicate", "dp_shard", "tp")'
        ),
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    # Validate explicitly instead of with ``assert``: assertions are stripped
    # under ``python -O`` and would otherwise let a 4+ dimensional mesh through.
    if len(args.sharding_dims) > 3:
        parser.error(
            f"--sharding-dims accepts at most 3 values, got {len(args.sharding_dims)}"
        )
    # Only the 3-D layout carries a tensor-parallel dimension.
    args.tp_size = args.sharding_dims[2] if len(args.sharding_dims) >= 3 else 1
    return args
79+
80+
81+
def _get_recipe(name):
    """
    Look up the attribute called ``name`` on ``transformer_engine.common.recipe``
    and return the result of calling it with no arguments.
    """
    recipe_factory = getattr(transformer_engine.common.recipe, name)
    return recipe_factory()
83+
84+
85+
def _get_device_mesh(world_size, sharding_dims):
    """
    Create the CUDA device mesh described by ``sharding_dims``.

    One value gives a ``("dp_shard",)`` mesh, two values give
    ``("dp_replicate", "dp_shard")``, and anything else is read as
    ``("dp_replicate", "dp_shard", "tp")`` using the first three values.
    The product of the dimensions used must equal ``world_size``.
    """
    dist_print(f"sharding-dims: {sharding_dims}")
    num_dims = len(sharding_dims)
    if num_dims == 1:
        assert sharding_dims[0] == world_size
        mesh_shape = (world_size,)
        dim_names = ("dp_shard",)
    elif num_dims == 2:
        replicate, shard = sharding_dims[0], sharding_dims[1]
        assert replicate * shard == world_size
        mesh_shape = (replicate, shard)
        dim_names = ("dp_replicate", "dp_shard")
    else:
        replicate, shard, tp = sharding_dims[0], sharding_dims[1], sharding_dims[2]
        assert replicate * shard * tp == world_size
        mesh_shape = (replicate, shard, tp)
        dim_names = ("dp_replicate", "dp_shard", "tp")
    return init_device_mesh("cuda", mesh_shape, mesh_dim_names=dim_names)
104+
105+
106+
def _build_model(args):
    """
    Build a stack of ``_NUM_LAYERS`` ``te.TransformerLayer`` modules on the meta device.

    ``args.tp_size`` and ``args.mesh`` choose the mesh-related keyword
    arguments passed to every layer.

    Returns:
        ``(model, inp_shape)`` where ``inp_shape`` is
        ``[_SEQ_LEN, _BATCH_SIZE, _HIDDEN_SIZE]``.
    """
    layer_kwargs = dict(
        params_dtype=torch.float32,
        device="meta",
        tp_size=args.tp_size,
        fuse_qkv_params=True,
    )
    if args.tp_size > 1:
        # Tensor-parallel run: weights live on the flattened (dp_shard, tp) mesh.
        layer_kwargs["tp_mesh"] = args.mesh["tp"]
        layer_kwargs["weight_mesh"] = args.mesh["dp_shard", "tp"]._flatten("weight_mesh")
        layer_kwargs["set_parallel_mode"] = True
    elif "dp_replicate" in args.mesh.mesh_dim_names:
        # HSDP without TP: only the dp_shard sub-mesh is passed as the weight mesh.
        layer_kwargs["weight_mesh"] = args.mesh["dp_shard"]

    layers = [
        te.TransformerLayer(_HIDDEN_SIZE, _FFN_SIZE, _NUM_HEADS, **layer_kwargs)
        for _ in range(_NUM_LAYERS)
    ]
    model = nn.Sequential(*layers)
    return model, [_SEQ_LEN, _BATCH_SIZE, _HIDDEN_SIZE]
128+
129+
130+
def _shard_model(model, mesh):
    """
    Apply ``fully_shard`` to every direct child of ``model`` and then to ``model``.

    The data-parallel sub-mesh is ``("dp_replicate", "dp_shard")`` when the
    mesh has a ``dp_replicate`` dimension, otherwise just ``("dp_shard",)``.
    Returns the same ``model`` object.
    """
    if "dp_replicate" in mesh.mesh_dim_names:
        dp_dims = ("dp_replicate", "dp_shard")
    else:
        dp_dims = ("dp_shard",)
    # Children first, root module last.
    for target in (*model.children(), model):
        fully_shard(target, mesh=mesh[dp_dims])
    return model
138+
139+
140+
@torch.no_grad()
def _test_fp8_fsdp2_allgather(model):
    """
    Check the FSDP2 all-gather of quantized parameters against a reference.

    The reference is built by dequantizing each local shard, all-gathering it
    manually and concatenating along dim 0. The model is then unsharded,
    every parameter is dequantized and compared with the reference via
    ``torch.testing.assert_close``, and the model is resharded again.
    """
    # Phase 1: reference all-gather of the dequantized local shards.
    reference = {}
    for name, param in model.named_parameters():
        assert isinstance(
            param, DTensor
        ), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor."
        shard = param._local_tensor
        param_mesh = param.device_mesh
        if param_mesh.ndim > 1:
            group = param_mesh.get_group(mesh_dim="dp_shard")
        else:
            group = param_mesh.get_group()
        group_size = dist.get_world_size(group=group)
        # Per the original author's note: zeros_like on the local tensor is
        # dispatched down the dequantization route, so the receive buffers
        # are high-precision tensors.
        buffers = [torch.zeros_like(shard) for _ in range(group_size)]
        dist.all_gather(buffers, shard.dequantize(), group=group)
        reference[name] = torch.cat(buffers, dim=0)

    # Phase 2: let FSDP2 all-gather; only FSDP2-wrapped modules expose unshard().
    for module in model.modules():
        if hasattr(module, "unshard"):
            module.unshard()

    # Phase 3: compare each unsharded parameter with its reference.
    for name, param in model.named_parameters():
        # With TP the parameter is still a DTensor after the FSDP2 all-gather
        # (weights are wrapped as DTensor shards over the TP group), so take
        # its local tensor in that case.
        gathered = param._local_tensor if isinstance(param, DTensor) else param
        torch.testing.assert_close(gathered.dequantize(), reference[name])

    # Phase 4: put the model back into its sharded state.
    for module in model.modules():
        if hasattr(module, "reshard"):
            module.reshard()
184+
185+
186+
def _main(args):
    """
    Run the all-gather correctness test on one torchrun worker.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    initializes NCCL, builds and shards the model under ``fp8_model_init``,
    performs one optimizer step and then calls ``_test_fp8_fsdp2_allgather``.

    Args:
        args: Namespace from ``_parse_args``; ``args.mesh`` is set here.

    Returns:
        0 on success (used as the process exit status).
    """
    global LOCAL_RANK
    # The script expects to be launched by torchrun.
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

    # Select the CUDA device by the node-local rank. The global rank is not a
    # valid device index on multi-node runs, and the tensors below are created
    # on cuda:{LOCAL_RANK}, so the current device must match.
    torch.cuda.set_device(LOCAL_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dist.init_process_group(backend="nccl", rank=WORLD_RANK, world_size=WORLD_SIZE)
    device = torch.device(f"cuda:{LOCAL_RANK}")

    mesh = _get_device_mesh(WORLD_SIZE, args.sharding_dims)
    # _build_model reads the mesh from args.
    args.mesh = mesh

    fp8_recipe = _get_recipe(args.recipe)

    with fp8_model_init(enabled=True, recipe=fp8_recipe):
        model, inp_shape = _build_model(args)

    model = _shard_model(model, mesh)

    # The model was built on the meta device; reset_parameters() is called on
    # every module that provides it after sharding.
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    # Run a training step to initialize FSDP2 lazy state and update quantization
    # scales before testing the allgather. Block-scaling formats (Float8BlockScaling,
    # NVFP4BlockScaling) only exhibit allgather inconsistencies after weight updates.
    input_data = torch.randn(inp_shape, device=device)
    target = torch.randn(inp_shape, device=device)
    # Only the NVFP4 recipe additionally runs under bfloat16 autocast.
    nvfp4_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if args.recipe == "NVFP4BlockScaling"
        else nullcontext()
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    optimizer.zero_grad()
    with nvfp4_ctx, te.autocast(enabled=True, recipe=fp8_recipe):
        output = model(input_data)
    loss = F.mse_loss(output, target)
    loss.backward()
    optimizer.step()

    _test_fp8_fsdp2_allgather(model)
    dist_print("test_fp8_fsdp2_allgather passed.")

    dist.destroy_process_group()
    return 0
237+
238+
239+
if __name__ == "__main__":
240+
sys.exit(_main(_parse_args()))

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -335,52 +335,6 @@ def restore_custom_attrs(module, custom_attrs):
335335
setattr(param, attr_name, attr_value)
336336

337337

338-
@torch.no_grad()
339-
def test_fp8_fsdp2_allgather(model):
340-
"""
341-
Compare the result of the FP8 AG by FSDP2 with a manual AG in FP32
342-
after dequantizing the FP8 values.
343-
"""
344-
# FP32 manual weight allgather
345-
fp32_allgathered_params = {}
346-
for name, param in model.named_parameters():
347-
assert isinstance(
348-
param, DTensor
349-
), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor."
350-
local_tensor = param._local_tensor
351-
device_mesh = param.device_mesh
352-
dist_group = (
353-
device_mesh.get_group(mesh_dim="dp_shard")
354-
if device_mesh.ndim > 1
355-
else device_mesh.get_group()
356-
)
357-
# Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch
358-
# for local_tensor will go down the dequantization route.
359-
gathered_tensor = [
360-
torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
361-
]
362-
dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
363-
full_tensor = torch.cat(gathered_tensor, dim=0)
364-
fp32_allgathered_params[name] = full_tensor
365-
# FP8 allgather using FSDP2
366-
for module in model.modules():
367-
# Not all modules are wrapped/sharded with FSDP2.
368-
if hasattr(module, "unshard"):
369-
module.unshard()
370-
# Make sure allgathered parameters match exactly
371-
for name, param in model.named_parameters():
372-
if isinstance(param, DTensor):
373-
# Will still be a DTensor in the case of TP, even after FSDP2 AG,
374-
# because we wrap our weights as DTensor shards over the TP group.
375-
param = param._local_tensor
376-
torch.testing.assert_close(param.dequantize(), fp32_allgathered_params[name])
377-
# Revert model to original sharded state
378-
for module in model.modules():
379-
# Not all modules are wrapped/sharded with FSDP2.
380-
if hasattr(module, "reshard"):
381-
module.reshard()
382-
383-
384338
def _train(args):
385339
"""
386340
Torch Distributed Initialization
@@ -488,11 +442,6 @@ def _train(args):
488442
optimizer.step()
489443
dist_print(f"Iteration {iteration} completed with loss {loss.item()}")
490444

491-
# Some of the FSDP states are lazy initialized during FSDP forward pass
492-
# so testing fp8 allgather at the end of the training loop.
493-
if args.fp8_init and args.recipe not in ("Float8BlockScaling", "NVFP4BlockScaling"):
494-
test_fp8_fsdp2_allgather(model)
495-
496445
"""
497446
DCP Checkpoint Testing
498447
"""
@@ -560,9 +509,9 @@ def _train(args):
560509
v_pt = s_post_train[key]
561510
if isinstance(v_pt, DTensor):
562511
v_pt = v_pt.to_local()
563-
assert not torch.allclose(v1, v_pt), (
564-
f"[{key}] Model weights should have changed after extra training steps"
565-
)
512+
assert not torch.allclose(
513+
v1, v_pt
514+
), f"[{key}] Model weights should have changed after extra training steps"
566515

567516
# Load the checkpoint.
568517
state_dict = {"app": AppState(model=model, optimizer=optimizer)}

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,53 @@ def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
101101
_run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
102102

103103

104+
## ── FP8 FSDP2 all-gather correctness test ───────────────────────────
105+
106+
107+
def _run_allgather_test(sharding_dims, recipe):
    """
    Launch ``run_fsdp2_allgather.py`` (next to this file) under ``torchrun``
    with ``NUM_PROCS`` processes, forwarding ``sharding_dims`` and ``recipe``.

    Raises ``subprocess.CalledProcessError`` if the command exits non-zero.
    """
    script = Path(__file__).parent.resolve() / "run_fsdp2_allgather.py"
    cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(script), "--sharding-dims"]
    cmd.extend(str(dim) for dim in sharding_dims)
    cmd.extend(("--recipe", recipe))
    subprocess.run(cmd, env=os.environ, check=True)
119+
120+
121+
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs.")
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize(
    "sharding_dims",
    (
        [NUM_PROCS],  # FSDP
        [2, NUM_PROCS // 2],  # HSDP
        [NUM_PROCS // 4, 2, 2],  # (H/F)SDP-TP
    ),
)
def test_fp8_fsdp2_allgather(sharding_dims, fp_recipe):
    """Verify FSDP2 FP8 all-gather matches a manual dequantize-then-gather reference."""
    block_scaled_recipes = ("Float8BlockScaling", "NVFP4BlockScaling")
    if fp_recipe in block_scaled_recipes:
        pytest.xfail(
            f"{fp_recipe}: block-scaled quantization formats are not supported by the "
            "FP8 FSDP2 all-gather correctness test."
        )

    # Zero-sized dimensions (integer division on small NUM_PROCS) are left out
    # of the product, so the device-count check below still applies to them.
    required_procs = math.prod(dim for dim in sharding_dims if dim != 0)
    if required_procs > NUM_PROCS:
        pytest.skip(
            f"Insufficient devices ({NUM_PROCS}) to test sharding configuration: {sharding_dims}"
        )

    _run_allgather_test(sharding_dims, fp_recipe)
149+
150+
104151
## ── FusedAdam + FSDP2 tests ─────────────────────────────────────────
105152

106153

transformer_engine/pytorch/distributed.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,13 +2035,22 @@ class _ToLocalIdentity(torch.autograd.Function):
20352035

20362036
@staticmethod
20372037
def forward(ctx, dtensor_param: DTensor) -> torch.Tensor:
2038+
"""
2039+
Forward implementation for DTensor.to_local().
2040+
For quantized parameters, does not shallow copy
2041+
the local Tensor.
2042+
"""
20382043
ctx.device_mesh = dtensor_param.device_mesh
20392044
ctx.placements = dtensor_param.placements
20402045
ctx.set_materialize_grads(False)
20412046
return dtensor_param._local_tensor
20422047

20432048
@staticmethod
20442049
def backward(ctx, grad_local):
2050+
"""
2051+
Backward implementation for DTensor.to_local().
2052+
Converts Tensor gradients to DTensor.
2053+
"""
20452054
if grad_local is None:
20462055
return None
20472056
return DTensor.from_local(

0 commit comments

Comments
 (0)