1- # Copyright (c) 2022-2025 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # Copyright (c) 2022-2026 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# See LICENSE for license information.
44
282820. TestGTPPrefetchDisabled – weight_prefetch=False: single-pass forward still works (multi-GPU)
292921. TestFuseWgradAccumulation – fuse_wgrad_accumulation=True: wgrad→main_grad (multi-GPU)
303022. TestGTPGradAccumHook – main_grad updated after reduce-scatter backward (multi-GPU)
31+ 23. TestWaitAsyncCommsFallback – wait_async_comms(finalize_after_drain=True) inline-accumulation fallback when _wgrad_rs_handle is None (single-process)
3132
3233Multi-GPU tests use torch.multiprocessing.spawn and are skipped when fewer
3334than the required CUDA devices are available.
@@ -71,8 +72,6 @@ def reset_fp8_state():
7172def reset_gtp_globals ():
7273 """Reset all GTP mutable class/module-level state between tests."""
7374 yield
74- GTPShardedParam ._first_weight_flag = True
75- GTPShardedParam ._pending_rs_weight = None
7675 GTPShardedParam ._chain_state = {}
7776
7877
@@ -1486,3 +1485,117 @@ class TestGTPGradAccumHook:
14861485 def test_main_grad_updated_after_backward (self ):
14871486 _requires_multi_gpu (4 )
14881487 _run_distributed (_worker_main_grad_updated_after_bwd , 4 )
1488+
1489+
1490+ # ---------------------------------------------------------------------------
1491+ # 24. wait_async_comms(finalize_after_drain=True) inline-accumulation fallback
1492+ # ---------------------------------------------------------------------------
1493+
1494+
1495+ class TestWaitAsyncCommsFallback :
1496+ """Exercises the inline-accumulation fallback inside
1497+ ``wait_async_comms(finalize_after_drain=True)``: when a param is in
1498+ ``_inflight_comm_params`` (async AG was issued) but its ``_wgrad_rs_handle``
1499+ is ``None`` (no async RS handle to drain), the inner
1500+ ``_wait_reduce_scatter`` call no-ops and the outer loop must inline the
1501+ accumulation itself (main_grad.add_ + ticket release + flag set).
1502+
1503+ Production flows rarely hit this combination — chain-interior params have
1504+ both async AG and async RS, and chain-head sync RS doesn't enter
1505+ ``_inflight_comm_params`` via bwd AG. We construct the state by hand to
1506+ pin down the fallback's contract.
1507+ """
1508+
1509+ class _FakeGroup :
1510+ def size (self ): return 1
1511+ def rank (self ): return 0
1512+
1513+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA required" )
1514+ def test_fallback_accumulates_when_no_rs_handle (self ):
1515+ dtype = torch .bfloat16
1516+ p = GTPShardedParam (torch .zeros (8 , 4 , dtype = dtype , device = "cuda" ))
1517+ p .group = self ._FakeGroup ()
1518+ p .expert_idx = None
1519+ p .pad_length = 0
1520+ p .chain_id = gtp_module .GTPChain .UNGRAPHED .value
1521+ p ._quantizer = None
1522+ p .is_routed_expert = False # ⇒ self._weights property returns [self]
1523+ p .main_grad = torch .zeros (8 , 4 , dtype = dtype , device = "cuda" )
1524+ p ._prefetch_handle = None # _wait_param_gather is no-op
1525+ p ._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
1526+ p ._cached_ag_stream = None
1527+ p ._cached_rs_stream = None
1528+ p .ag_event = torch .cuda .Event (external = True )
1529+ p .rs_event = torch .cuda .Event (external = True )
1530+ p .rs_event .record () # so rs_event.wait() in fallback doesn't block
1531+ p ._already_finalized = False
1532+ p .grad_added_to_main_grad = False
1533+
1534+ # Place a known wgrad in the cache for the fallback to read.
1535+ cache = gtp_module .get_global_GTP_cache ()
1536+ p ._rs_ticket = cache .reserve (p , dtype , fwd = False , reduce_scatter = True )
1537+ cache .get (p ._rs_ticket ).fill_ (2.0 )
1538+
1539+ # Save + replace _inflight_comm_params so we don't trip over leftover
1540+ # params from earlier tests in the loop.
1541+ saved = set (gtp_module ._inflight_comm_params )
1542+ gtp_module ._inflight_comm_params .clear ()
1543+ gtp_module ._inflight_comm_params .add (p )
1544+ try :
1545+ gtp_module .wait_async_comms (
1546+ chain_id = p .chain_id ,
1547+ skip_rs = False ,
1548+ finalize_after_drain = True ,
1549+ )
1550+ finally :
1551+ gtp_module ._inflight_comm_params .clear ()
1552+ gtp_module ._inflight_comm_params .update (saved )
1553+
1554+ torch .cuda .synchronize ()
1555+ assert torch .all (p .main_grad == 2.0 ), \
1556+ f"main_grad should be 2.0 after fallback accumulation; got { p .main_grad } "
1557+ assert p ._already_finalized is True , "_already_finalized must be set"
1558+ assert p .grad_added_to_main_grad is True , "grad_added_to_main_grad must be set"
1559+
1560+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA required" )
1561+ def test_fallback_skipped_when_already_finalized (self ):
1562+ """When _already_finalized=True, the fallback must NOT re-accumulate."""
1563+ dtype = torch .bfloat16
1564+ p = GTPShardedParam (torch .zeros (8 , 4 , dtype = dtype , device = "cuda" ))
1565+ p .group = self ._FakeGroup ()
1566+ p .expert_idx = None
1567+ p .pad_length = 0
1568+ p .chain_id = gtp_module .GTPChain .UNGRAPHED .value
1569+ p ._quantizer = None
1570+ p .is_routed_expert = False # ⇒ self._weights property returns [self]
1571+ # Pre-existing main_grad with a value the fallback must NOT overwrite.
1572+ p .main_grad = torch .full ((8 , 4 ), 5.0 , dtype = dtype , device = "cuda" )
1573+ p ._prefetch_handle = None
1574+ p ._wgrad_rs_handle = None
1575+ p ._cached_ag_stream = None
1576+ p ._cached_rs_stream = None
1577+ p .ag_event = torch .cuda .Event (external = True )
1578+ p .rs_event = torch .cuda .Event (external = True )
1579+ p .rs_event .record ()
1580+ p ._already_finalized = True # ← short-circuits the fallback
1581+
1582+ # No _rs_ticket: if the fallback ran it would AttributeError on
1583+ # cache.get(None). The skip path must not touch the cache at all.
1584+ p ._rs_ticket = None
1585+
1586+ saved = set (gtp_module ._inflight_comm_params )
1587+ gtp_module ._inflight_comm_params .clear ()
1588+ gtp_module ._inflight_comm_params .add (p )
1589+ try :
1590+ gtp_module .wait_async_comms (
1591+ chain_id = p .chain_id ,
1592+ skip_rs = False ,
1593+ finalize_after_drain = True ,
1594+ )
1595+ finally :
1596+ gtp_module ._inflight_comm_params .clear ()
1597+ gtp_module ._inflight_comm_params .update (saved )
1598+
1599+ torch .cuda .synchronize ()
1600+ assert torch .all (p .main_grad == 5.0 ), \
1601+ "main_grad must be untouched when _already_finalized=True"
0 commit comments