Skip to content

Commit e29f26c

Browse files
committed
fix lint; fix comments
1 parent fc8a940 commit e29f26c

6 files changed

Lines changed: 218 additions & 58 deletions

File tree

tests/pytorch/distributed/test_gtp.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# See LICENSE for license information.
44

@@ -28,6 +28,7 @@
2828
20. TestGTPPrefetchDisabled – weight_prefetch=False: single-pass forward still works (multi-GPU)
2929
21. TestFuseWgradAccumulation – fuse_wgrad_accumulation=True: wgrad→main_grad (multi-GPU)
3030
22. TestGTPGradAccumHook – main_grad updated after reduce-scatter backward (multi-GPU)
31+
23. TestWaitAsyncCommsFallback – wait_async_comms(finalize_after_drain=True) inline-accumulation fallback when _wgrad_rs_handle is None (single-process)
3132
3233
Multi-GPU tests use torch.multiprocessing.spawn and are skipped when fewer
3334
than the required CUDA devices are available.
@@ -71,8 +72,6 @@ def reset_fp8_state():
7172
def reset_gtp_globals():
7273
"""Reset all GTP mutable class/module-level state between tests."""
7374
yield
74-
GTPShardedParam._first_weight_flag = True
75-
GTPShardedParam._pending_rs_weight = None
7675
GTPShardedParam._chain_state = {}
7776

7877

@@ -1486,3 +1485,117 @@ class TestGTPGradAccumHook:
14861485
def test_main_grad_updated_after_backward(self):
14871486
_requires_multi_gpu(4)
14881487
_run_distributed(_worker_main_grad_updated_after_bwd, 4)
1488+
1489+
1490+
# ---------------------------------------------------------------------------
1491+
# 23. wait_async_comms(finalize_after_drain=True) inline-accumulation fallback
1492+
# ---------------------------------------------------------------------------
1493+
1494+
1495+
class TestWaitAsyncCommsFallback:
1496+
"""Exercises the inline-accumulation fallback inside
1497+
``wait_async_comms(finalize_after_drain=True)``: when a param is in
1498+
``_inflight_comm_params`` (async AG was issued) but its ``_wgrad_rs_handle``
1499+
is ``None`` (no async RS handle to drain), the inner
1500+
``_wait_reduce_scatter`` call no-ops and the outer loop must inline the
1501+
accumulation itself (main_grad.add_ + ticket release + flag set).
1502+
1503+
Production flows rarely hit this combination — chain-interior params have
1504+
both async AG and async RS, and chain-head sync RS doesn't enter
1505+
``_inflight_comm_params`` via bwd AG. We construct the state by hand to
1506+
pin down the fallback's contract.
1507+
"""
1508+
1509+
class _FakeGroup:
1510+
def size(self): return 1
1511+
def rank(self): return 0
1512+
1513+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
1514+
def test_fallback_accumulates_when_no_rs_handle(self):
1515+
dtype = torch.bfloat16
1516+
p = GTPShardedParam(torch.zeros(8, 4, dtype=dtype, device="cuda"))
1517+
p.group = self._FakeGroup()
1518+
p.expert_idx = None
1519+
p.pad_length = 0
1520+
p.chain_id = gtp_module.GTPChain.UNGRAPHED.value
1521+
p._quantizer = None
1522+
p.is_routed_expert = False # ⇒ self._weights property returns [self]
1523+
p.main_grad = torch.zeros(8, 4, dtype=dtype, device="cuda")
1524+
p._prefetch_handle = None # _wait_param_gather is no-op
1525+
p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
1526+
p._cached_ag_stream = None
1527+
p._cached_rs_stream = None
1528+
p.ag_event = torch.cuda.Event(external=True)
1529+
p.rs_event = torch.cuda.Event(external=True)
1530+
p.rs_event.record() # so rs_event.wait() in fallback doesn't block
1531+
p._already_finalized = False
1532+
p.grad_added_to_main_grad = False
1533+
1534+
# Place a known wgrad in the cache for the fallback to read.
1535+
cache = gtp_module.get_global_GTP_cache()
1536+
p._rs_ticket = cache.reserve(p, dtype, fwd=False, reduce_scatter=True)
1537+
cache.get(p._rs_ticket).fill_(2.0)
1538+
1539+
# Save + replace _inflight_comm_params so we don't trip over leftover
1540+
# params from earlier tests in the loop.
1541+
saved = set(gtp_module._inflight_comm_params)
1542+
gtp_module._inflight_comm_params.clear()
1543+
gtp_module._inflight_comm_params.add(p)
1544+
try:
1545+
gtp_module.wait_async_comms(
1546+
chain_id=p.chain_id,
1547+
skip_rs=False,
1548+
finalize_after_drain=True,
1549+
)
1550+
finally:
1551+
gtp_module._inflight_comm_params.clear()
1552+
gtp_module._inflight_comm_params.update(saved)
1553+
1554+
torch.cuda.synchronize()
1555+
assert torch.all(p.main_grad == 2.0), \
1556+
f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}"
1557+
assert p._already_finalized is True, "_already_finalized must be set"
1558+
assert p.grad_added_to_main_grad is True, "grad_added_to_main_grad must be set"
1559+
1560+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
1561+
def test_fallback_skipped_when_already_finalized(self):
1562+
"""When _already_finalized=True, the fallback must NOT re-accumulate."""
1563+
dtype = torch.bfloat16
1564+
p = GTPShardedParam(torch.zeros(8, 4, dtype=dtype, device="cuda"))
1565+
p.group = self._FakeGroup()
1566+
p.expert_idx = None
1567+
p.pad_length = 0
1568+
p.chain_id = gtp_module.GTPChain.UNGRAPHED.value
1569+
p._quantizer = None
1570+
p.is_routed_expert = False # ⇒ self._weights property returns [self]
1571+
# Pre-existing main_grad with a value the fallback must NOT overwrite.
1572+
p.main_grad = torch.full((8, 4), 5.0, dtype=dtype, device="cuda")
1573+
p._prefetch_handle = None
1574+
p._wgrad_rs_handle = None
1575+
p._cached_ag_stream = None
1576+
p._cached_rs_stream = None
1577+
p.ag_event = torch.cuda.Event(external=True)
1578+
p.rs_event = torch.cuda.Event(external=True)
1579+
p.rs_event.record()
1580+
p._already_finalized = True # ← short-circuits the fallback
1581+
1582+
# No _rs_ticket: if the fallback ran it would raise AttributeError on
1583+
# cache.get(None). The skip path must not touch the cache at all.
1584+
p._rs_ticket = None
1585+
1586+
saved = set(gtp_module._inflight_comm_params)
1587+
gtp_module._inflight_comm_params.clear()
1588+
gtp_module._inflight_comm_params.add(p)
1589+
try:
1590+
gtp_module.wait_async_comms(
1591+
chain_id=p.chain_id,
1592+
skip_rs=False,
1593+
finalize_after_drain=True,
1594+
)
1595+
finally:
1596+
gtp_module._inflight_comm_params.clear()
1597+
gtp_module._inflight_comm_params.update(saved)
1598+
1599+
torch.cuda.synchronize()
1600+
assert torch.all(p.main_grad == 5.0), \
1601+
"main_grad must be untouched when _already_finalized=True"

tests/pytorch/distributed/test_tp_gtp.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# See LICENSE for license information.
44

@@ -17,7 +17,6 @@
1717
2. TestTPGTPColumnParallelLinear – column-parallel Linear: weight shape + fwd/bwd correctness
1818
3. TestTPGTPRowParallelLinear – row-parallel Linear: weight shape + fwd/bwd smoke test
1919
4. TestTPGTPLayerNormLinear – LayerNormLinear column-parallel smoke test
20-
5. TestTPGTPLayerNormMLP – LayerNormMLP (column FC1 + row FC2) smoke test
2120
2221
Tests use (tp_size, gtp_size) = (2, 2) → world_size = 4 (runs on 4-GPU machines).
2322
@@ -53,8 +52,6 @@ def reset_fp8_state():
5352
def reset_gtp_globals():
5453
"""Reset GTP mutable class/module-level state between tests."""
5554
yield
56-
GTPShardedParam._first_weight_flag = True
57-
GTPShardedParam._pending_rs_weight = None
5855
GTPShardedParam._chain_state = {}
5956

6057

transformer_engine/pytorch/distributed.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,15 +1284,13 @@ def _post_process_nvfp4_gather(
12841284
handle.wait()
12851285
handle = None
12861286

1287-
# TODO
1288-
# # Fix the interleaved transposed data from gathering along first dim.
1289-
# out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
1290-
# out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
1287+
# Fix the interleaved transposed data from gathering along first dim.
1288+
# In-place .copy_() (not `=` rebind) to keep the storage address stable
1289+
# for CUDA graph capture — replays see the same pointer they captured.
12911290
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
12921291
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))
12931292

1294-
# # Optionally pad the scaling inverse if needed.
1295-
# out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
1293+
# Optionally pad the scaling inverse if needed (same in-place pattern).
12961294
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
12971295

12981296

@@ -1308,6 +1306,10 @@ class _NVFP4AllGatherAsyncHandle:
13081306
_synchronized: bool = False
13091307

13101308
def post_process_nvfp4_gather(self) -> None:
1309+
"""Fix interleaved transposed data + pad scale_inv after the async AG completes.
1310+
1311+
Idempotent: gated by ``_synchronized`` in :meth:`wait`.
1312+
"""
13111313
_post_process_nvfp4_gather(
13121314
self.output,
13131315
self.columnwise_data_interleaved,
@@ -1454,9 +1456,8 @@ def _all_gather_nvfp4(
14541456
group=process_group,
14551457
)
14561458

1457-
# Transfer amax to output.
1458-
# TODO: jiemingz
1459-
# out._amax_rowwise = inp._amax_rowwise
1459+
# Transfer amax to output via in-place .copy_() so the storage
1460+
# address stays stable for CUDA graph capture.
14601461
out._amax_rowwise.copy_(inp._amax_rowwise)
14611462

14621463
# Gather the transposed NVFP4 data along first dimension. Fix format later.

0 commit comments

Comments
 (0)