@@ -1507,8 +1507,11 @@ class TestWaitAsyncCommsFallback:
15071507 """
15081508
15091509 class _FakeGroup :
1510- def size (self ): return 1
1511- def rank (self ): return 0
1510+ def size (self ):
1511+ return 1
1512+
1513+ def rank (self ):
1514+ return 0
15121515
15131516 @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA required" )
15141517 def test_fallback_accumulates_when_no_rs_handle (self ):
@@ -1519,15 +1522,15 @@ def test_fallback_accumulates_when_no_rs_handle(self):
15191522 p .pad_length = 0
15201523 p .chain_id = gtp_module .GTPChain .UNGRAPHED .value
15211524 p ._quantizer = None
1522- p .is_routed_expert = False # ⇒ self._weights property returns [self]
1525+ p .is_routed_expert = False # ⇒ self._weights property returns [self]
15231526 p .main_grad = torch .zeros (8 , 4 , dtype = dtype , device = "cuda" )
1524- p ._prefetch_handle = None # _wait_param_gather is no-op
1525- p ._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
1527+ p ._prefetch_handle = None # _wait_param_gather is no-op
1528+ p ._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
15261529 p ._cached_ag_stream = None
15271530 p ._cached_rs_stream = None
15281531 p .ag_event = torch .cuda .Event (external = True )
15291532 p .rs_event = torch .cuda .Event (external = True )
1530- p .rs_event .record () # so rs_event.wait() in fallback doesn't block
1533+ p .rs_event .record () # so rs_event.wait() in fallback doesn't block
15311534 p ._already_finalized = False
15321535 p .grad_added_to_main_grad = False
15331536
@@ -1552,8 +1555,9 @@ def test_fallback_accumulates_when_no_rs_handle(self):
15521555 gtp_module ._inflight_comm_params .update (saved )
15531556
15541557 torch .cuda .synchronize ()
1555- assert torch .all (p .main_grad == 2.0 ), \
1556- f"main_grad should be 2.0 after fallback accumulation; got { p .main_grad } "
1558+ assert torch .all (
1559+ p .main_grad == 2.0
1560+ ), f"main_grad should be 2.0 after fallback accumulation; got { p .main_grad } "
15571561 assert p ._already_finalized is True , "_already_finalized must be set"
15581562 assert p .grad_added_to_main_grad is True , "grad_added_to_main_grad must be set"
15591563
@@ -1567,7 +1571,7 @@ def test_fallback_skipped_when_already_finalized(self):
15671571 p .pad_length = 0
15681572 p .chain_id = gtp_module .GTPChain .UNGRAPHED .value
15691573 p ._quantizer = None
1570- p .is_routed_expert = False # ⇒ self._weights property returns [self]
1574+ p .is_routed_expert = False # ⇒ self._weights property returns [self]
15711575 # Pre-existing main_grad with a value the fallback must NOT overwrite.
15721576 p .main_grad = torch .full ((8 , 4 ), 5.0 , dtype = dtype , device = "cuda" )
15731577 p ._prefetch_handle = None
@@ -1577,7 +1581,7 @@ def test_fallback_skipped_when_already_finalized(self):
15771581 p .ag_event = torch .cuda .Event (external = True )
15781582 p .rs_event = torch .cuda .Event (external = True )
15791583 p .rs_event .record ()
1580- p ._already_finalized = True # ← short-circuits the fallback
1584+ p ._already_finalized = True # ← short-circuits the fallback
15811585
15821586 # No _rs_ticket: if the fallback ran it would AttributeError on
15831587 # cache.get(None). The skip path must not touch the cache at all.
@@ -1597,5 +1601,6 @@ def test_fallback_skipped_when_already_finalized(self):
15971601 gtp_module ._inflight_comm_params .update (saved )
15981602
15991603 torch .cuda .synchronize ()
1600- assert torch .all (p .main_grad == 5.0 ), \
1601- "main_grad must be untouched when _already_finalized=True"
1604+ assert torch .all (
1605+ p .main_grad == 5.0
1606+ ), "main_grad must be untouched when _already_finalized=True"
0 commit comments