Skip to content

Commit 8743a23

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e29f26c commit 8743a23

2 files changed

Lines changed: 19 additions & 15 deletions

File tree

tests/pytorch/distributed/test_gtp.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1507,8 +1507,11 @@ class TestWaitAsyncCommsFallback:
15071507
"""
15081508

15091509
class _FakeGroup:
1510-
def size(self): return 1
1511-
def rank(self): return 0
1510+
def size(self):
1511+
return 1
1512+
1513+
def rank(self):
1514+
return 0
15121515

15131516
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
15141517
def test_fallback_accumulates_when_no_rs_handle(self):
@@ -1519,15 +1522,15 @@ def test_fallback_accumulates_when_no_rs_handle(self):
15191522
p.pad_length = 0
15201523
p.chain_id = gtp_module.GTPChain.UNGRAPHED.value
15211524
p._quantizer = None
1522-
p.is_routed_expert = False # ⇒ self._weights property returns [self]
1525+
p.is_routed_expert = False # ⇒ self._weights property returns [self]
15231526
p.main_grad = torch.zeros(8, 4, dtype=dtype, device="cuda")
1524-
p._prefetch_handle = None # _wait_param_gather is no-op
1525-
p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
1527+
p._prefetch_handle = None # _wait_param_gather is no-op
1528+
p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires
15261529
p._cached_ag_stream = None
15271530
p._cached_rs_stream = None
15281531
p.ag_event = torch.cuda.Event(external=True)
15291532
p.rs_event = torch.cuda.Event(external=True)
1530-
p.rs_event.record() # so rs_event.wait() in fallback doesn't block
1533+
p.rs_event.record() # so rs_event.wait() in fallback doesn't block
15311534
p._already_finalized = False
15321535
p.grad_added_to_main_grad = False
15331536

@@ -1552,8 +1555,9 @@ def test_fallback_accumulates_when_no_rs_handle(self):
15521555
gtp_module._inflight_comm_params.update(saved)
15531556

15541557
torch.cuda.synchronize()
1555-
assert torch.all(p.main_grad == 2.0), \
1556-
f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}"
1558+
assert torch.all(
1559+
p.main_grad == 2.0
1560+
), f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}"
15571561
assert p._already_finalized is True, "_already_finalized must be set"
15581562
assert p.grad_added_to_main_grad is True, "grad_added_to_main_grad must be set"
15591563

@@ -1567,7 +1571,7 @@ def test_fallback_skipped_when_already_finalized(self):
15671571
p.pad_length = 0
15681572
p.chain_id = gtp_module.GTPChain.UNGRAPHED.value
15691573
p._quantizer = None
1570-
p.is_routed_expert = False # ⇒ self._weights property returns [self]
1574+
p.is_routed_expert = False # ⇒ self._weights property returns [self]
15711575
# Pre-existing main_grad with a value the fallback must NOT overwrite.
15721576
p.main_grad = torch.full((8, 4), 5.0, dtype=dtype, device="cuda")
15731577
p._prefetch_handle = None
@@ -1577,7 +1581,7 @@ def test_fallback_skipped_when_already_finalized(self):
15771581
p.ag_event = torch.cuda.Event(external=True)
15781582
p.rs_event = torch.cuda.Event(external=True)
15791583
p.rs_event.record()
1580-
p._already_finalized = True # ← short-circuits the fallback
1584+
p._already_finalized = True # ← short-circuits the fallback
15811585

15821586
# No _rs_ticket: if the fallback ran it would raise AttributeError on
15831587
# cache.get(None). The skip path must not touch the cache at all.
@@ -1597,5 +1601,6 @@ def test_fallback_skipped_when_already_finalized(self):
15971601
gtp_module._inflight_comm_params.update(saved)
15981602

15991603
torch.cuda.synchronize()
1600-
assert torch.all(p.main_grad == 5.0), \
1601-
"main_grad must be untouched when _already_finalized=True"
1604+
assert torch.all(
1605+
p.main_grad == 5.0
1606+
), "main_grad must be untouched when _already_finalized=True"

transformer_engine/pytorch/module/generalized_tensor_parallelism.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def classify_gtp_chains(model) -> None:
143143

144144
class GTPWeightState(Enum):
145145
"""State of a GTPShardedParam's AG / RS lifecycle (debug / stale-read guard)."""
146+
146147
NONE = "NONE" # Sharded, no pending operation
147148
ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress
148149
DATA_READY = "DATA_READY" # Async all-gather complete, result in cache
@@ -1304,9 +1305,7 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None):
13041305
async_ops=async_op,
13051306
) as cm:
13061307
for out_buffer, tensor in zip(out_buffers, wgrads):
1307-
out, _ = reduce_scatter_along_first_dim(
1308-
tensor, self.group, output=out_buffer
1309-
)
1308+
out, _ = reduce_scatter_along_first_dim(tensor, self.group, output=out_buffer)
13101309
outputs.append(out)
13111310
nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs")
13121311

0 commit comments

Comments
 (0)