44Run with:
55 torchrun --nproc_per_node=2 source/tests/pt/test_sezm_moe_a2a_multigpu.py
66 torchrun --nproc_per_node=4 source/tests/pt/test_sezm_moe_a2a_multigpu.py
7+ torchrun --nproc_per_node=8 source/tests/pt/test_sezm_moe_a2a_multigpu.py
78"""
89
910import unittest
def setup_dist():
    """Initialize the distributed environment and pick this rank's device.

    The process group is created only if one is not already initialized.
    NCCL is used when CUDA is available, otherwise gloo.

    Returns
    -------
    tuple
        ``(rank, world_size, device)`` where ``device`` is a CUDA device
        selected round-robin from the visible GPUs, or the CPU device when
        CUDA is unavailable.
    """
    use_cuda = torch.cuda.is_available()
    if not dist.is_initialized():
        backend = "nccl" if use_cuda else "gloo"
        dist.init_process_group(backend=backend)

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if use_cuda:
        # Wrap around the visible GPUs so a launch with more ranks than
        # devices still maps every rank to a valid index.
        device_index = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_index)
        device = torch.device("cuda", device_index)
    else:
        device = torch.device("cpu")
    return rank, world_size, device
2833
2934
def cleanup_dist():
    """Tear down the distributed environment if it is initialized.

    Does nothing when no process group exists.
    """
    if not dist.is_initialized():
        return
    # Wait for every rank to arrive before the group is destroyed.
    dist.barrier()
    dist.destroy_process_group()
3440
3541
def make_cyclic_splits(rank: int, world_size: int):
    """Return deterministic asymmetric splits valid for any world size.

    Every count lies in ``1..5``. The two formulas mirror each other, so the
    number of rows rank ``i`` sends to rank ``j`` equals the number of rows
    rank ``j`` expects from rank ``i``:
    ``send_splits(i)[j] == recv_splits(j)[i] == ((i + 2 * j) % 5) + 1``.

    Parameters
    ----------
    rank : int
        Rank whose splits are being generated.
    world_size : int
        Number of participating ranks (length of both returned lists).

    Returns
    -------
    tuple
        ``(send_splits, recv_splits)``, two lists of ``world_size`` ints.
    """
    send_splits = [((rank + 2 * peer) % 5) + 1 for peer in range(world_size)]
    recv_splits = [((peer + 2 * rank) % 5) + 1 for peer in range(world_size)]
    return send_splits, recv_splits
47+
48+
def make_encoded_input(rank: int, send_splits, device):
    """Build rows whose values encode source rank, target rank, and row id.

    For each peer ``p`` with ``send_splits[p] == n`` the result contains the
    rows ``[rank, p, 0] .. [rank, p, n - 1]``, grouped by peer in ascending
    order.

    Parameters
    ----------
    rank : int
        Rank that owns (sends) the rows.
    send_splits : sequence of int
        Number of rows destined for each peer.
    device : torch.device or str
        Device on which the tensor is created.

    Returns
    -------
    torch.Tensor
        float64 tensor of shape ``(sum(send_splits), 3)``.
    """
    rows = [
        [float(rank), float(peer), float(row_id)]
        for peer, count in enumerate(send_splits)
        for row_id in range(count)
    ]
    encoded = torch.tensor(rows, dtype=torch.float64, device=device)
    # An empty row list would otherwise produce shape (0,); keep the
    # three-column layout even when nothing is sent.
    return encoded.reshape(len(rows), 3)
56+
57+
def make_expected_encoded_output(rank: int, world_size: int, device):
    """Expected all-to-all output for make_encoded_input and make_cyclic_splits.

    Rows are ordered by source rank. For each source the rows are
    ``[source_rank, rank, row_id]`` with ``row_id`` counting up to the number
    of rows that source sends to ``rank`` according to make_cyclic_splits.

    Parameters
    ----------
    rank : int
        Receiving rank.
    world_size : int
        Number of participating ranks.
    device : torch.device or str
        Device on which the tensor is created.

    Returns
    -------
    torch.Tensor
        float64 tensor with three columns, one row per received row.
    """
    rows = []
    for source_rank in range(world_size):
        source_send_splits, _ = make_cyclic_splits(source_rank, world_size)
        # Only the slice each source addresses to this rank arrives here.
        count = source_send_splits[rank]
        rows.extend(
            [float(source_rank), float(rank), float(row_id)]
            for row_id in range(count)
        )
    expected = torch.tensor(rows, dtype=torch.float64, device=device)
    # Keep the three-column layout even if no rows are expected.
    return expected.reshape(len(rows), 3)
67+
68+
3669class TestAllToAllMultiGPU (unittest .TestCase ):
3770 """Multi-GPU tests for _AllToAllDouble communication primitive."""
3871
@@ -44,44 +77,19 @@ def setUpClass(cls):
4477
4578 @classmethod
4679 def tearDownClass (cls ):
47- """Clean up distributed environment."""
48- cleanup_dist ()
80+ """Keep the process group alive until run_tests aggregates results."""
4981
50- def test_forward_shape (self ):
51- """Forward pass should produce correct output shape across ranks."""
52- # Each rank sends different amounts
53- # Constraint: rank i's send_splits[j] == rank j's recv_splits[i]
54- if self .world_size == 2 :
55- send_splits = [3 , 5 ] if self .rank == 0 else [2 , 6 ]
56- recv_splits = [3 , 2 ] if self .rank == 0 else [5 , 6 ]
57- elif self .world_size == 4 :
58- # Matrix: send[i][j] = recv[j][i]
59- # rank 0 sends: [2, 3, 1, 4] -> rank 0 recvs: [2, 5, 3, 7]
60- # rank 1 sends: [5, 2, 4, 3] -> rank 1 recvs: [3, 2, 6, 4]
61- # rank 2 sends: [3, 6, 1, 2] -> rank 2 recvs: [1, 4, 1, 5]
62- # rank 3 sends: [7, 4, 5, 1] -> rank 3 recvs: [4, 3, 2, 1]
63- if self .rank == 0 :
64- send_splits = [2 , 3 , 1 , 4 ]
65- recv_splits = [2 , 5 , 3 , 7 ]
66- elif self .rank == 1 :
67- send_splits = [5 , 2 , 4 , 3 ]
68- recv_splits = [3 , 2 , 6 , 4 ]
69- elif self .rank == 2 :
70- send_splits = [3 , 6 , 1 , 2 ]
71- recv_splits = [1 , 4 , 1 , 5 ]
72- else : # rank 3
73- send_splits = [7 , 4 , 5 , 1 ]
74- recv_splits = [4 , 3 , 2 , 1 ]
75- else :
76- self .skipTest (f"Test not configured for world_size={ self .world_size } " )
82+ def test_forward_values_and_shape (self ):
83+ """Forward pass should move the correct rows across ranks."""
84+ send_splits , recv_splits = make_cyclic_splits (self .rank , self .world_size )
7785
7886 total_send = sum (send_splits )
7987 total_recv = sum (recv_splits )
8088
81- x = torch . randn ( total_send , 8 , device = self .device , requires_grad = True )
89+ x = make_encoded_input ( self . rank , send_splits , self .device ). requires_grad_ ( True )
8290 out = all_to_all_differentiable (x , send_splits , recv_splits , self .group )
91+ expected = make_expected_encoded_output (self .rank , self .world_size , self .device )
8392
84- # Check output shape
8593 self .assertEqual (
8694 out .shape [0 ],
8795 total_recv ,
@@ -92,20 +100,17 @@ def test_forward_shape(self):
92100 x .shape [1 :],
93101 f"Rank { self .rank } : trailing dimensions should be preserved" ,
94102 )
103+ torch .testing .assert_close (out , expected )
95104
96105 def test_backward_no_deadlock (self ):
97106 """Backward pass should not deadlock."""
98- if self .world_size == 2 :
99- send_splits = [4 , 4 ]
100- recv_splits = [4 , 4 ]
101- elif self .world_size == 4 :
102- send_splits = [2 , 2 , 2 , 2 ]
103- recv_splits = [2 , 2 , 2 , 2 ]
104- else :
105- self .skipTest (f"Test not configured for world_size={ self .world_size } " )
107+ send_splits = [2 ] * self .world_size
108+ recv_splits = [2 ] * self .world_size
106109
107110 total_send = sum (send_splits )
108- x = torch .randn (total_send , 8 , device = self .device , requires_grad = True )
111+ x = torch .randn (
112+ total_send , 8 , device = self .device , dtype = torch .float64 , requires_grad = True
113+ )
109114
110115 out = all_to_all_differentiable (x , send_splits , recv_splits , self .group )
111116 loss = (out ** 2 ).sum ()
@@ -120,17 +125,13 @@ def test_backward_no_deadlock(self):
120125
121126 def test_second_backward_no_deadlock (self ):
122127 """Second backward (create_graph=True) should not deadlock."""
123- if self .world_size == 2 :
124- send_splits = [3 , 3 ]
125- recv_splits = [3 , 3 ]
126- elif self .world_size == 4 :
127- send_splits = [2 , 2 , 2 , 2 ]
128- recv_splits = [2 , 2 , 2 , 2 ]
129- else :
130- self .skipTest (f"Test not configured for world_size={ self .world_size } " )
128+ send_splits = [2 ] * self .world_size
129+ recv_splits = [2 ] * self .world_size
131130
132131 total_send = sum (send_splits )
133- x = torch .randn (total_send , 8 , device = self .device , requires_grad = True )
132+ x = torch .randn (
133+ total_send , 8 , device = self .device , dtype = torch .float64 , requires_grad = True
134+ )
134135
135136 # First forward
136137 out = all_to_all_differentiable (x , send_splits , recv_splits , self .group )
@@ -157,36 +158,19 @@ def test_second_backward_no_deadlock(self):
157158
158159 def test_asymmetric_splits (self ):
159160 """Test with asymmetric send/recv splits across ranks."""
160- # Constraint: rank i's send_splits[j] == rank j's recv_splits[i]
161- if self .world_size == 2 :
162- # Rank 0 sends more to rank 1, rank 1 sends more to rank 0
163- send_splits = [2 , 6 ] if self .rank == 0 else [5 , 3 ]
164- recv_splits = [2 , 5 ] if self .rank == 0 else [6 , 3 ]
165- elif self .world_size == 4 :
166- # Matrix: send[i][j] = recv[j][i]
167- # rank 0 sends: [1, 2, 3, 4] -> rank 0 recvs: [1, 3, 2, 4]
168- # rank 1 sends: [3, 2, 1, 4] -> rank 1 recvs: [2, 2, 3, 3]
169- # rank 2 sends: [2, 3, 4, 1] -> rank 2 recvs: [3, 1, 4, 2]
170- # rank 3 sends: [4, 3, 2, 1] -> rank 3 recvs: [4, 4, 1, 1]
171- if self .rank == 0 :
172- send_splits = [1 , 2 , 3 , 4 ]
173- recv_splits = [1 , 3 , 2 , 4 ]
174- elif self .rank == 1 :
175- send_splits = [3 , 2 , 1 , 4 ]
176- recv_splits = [2 , 2 , 3 , 3 ]
177- elif self .rank == 2 :
178- send_splits = [2 , 3 , 4 , 1 ]
179- recv_splits = [3 , 1 , 4 , 2 ]
180- else : # rank 3
181- send_splits = [4 , 3 , 2 , 1 ]
182- recv_splits = [4 , 4 , 1 , 1 ]
183- else :
184- self .skipTest (f"Test not configured for world_size={ self .world_size } " )
161+ send_splits , recv_splits = make_cyclic_splits (self .rank , self .world_size )
162+ self .assertNotEqual (
163+ send_splits ,
164+ recv_splits ,
165+ f"Rank { self .rank } : split pattern should be asymmetric" ,
166+ )
185167
186168 total_send = sum (send_splits )
187169 total_recv = sum (recv_splits )
188170
189- x = torch .randn (total_send , 16 , device = self .device , requires_grad = True )
171+ x = torch .randn (
172+ total_send , 16 , device = self .device , dtype = torch .float64 , requires_grad = True
173+ )
190174 out = all_to_all_differentiable (x , send_splits , recv_splits , self .group )
191175
192176 # Check shape
@@ -198,12 +182,73 @@ def test_asymmetric_splits(self):
198182 loss .backward ()
199183 self .assertIsNotNone (x .grad )
200184
185+ def test_three_layer_second_backward_no_deadlock (self ):
186+ """Three chained A2A ops should support second backward."""
187+ send_splits = [1 ] * self .world_size
188+ recv_splits = [1 ] * self .world_size
189+ x = torch .randn (
190+ self .world_size ,
191+ 4 ,
192+ dtype = torch .float64 ,
193+ device = self .device ,
194+ requires_grad = True ,
195+ )
196+
197+ y = x
198+ for _ in range (3 ):
199+ y = all_to_all_differentiable (y , send_splits , recv_splits , self .group )
200+
201+ loss = (y ** 2 ).sum ()
202+ (grad_x ,) = torch .autograd .grad (loss , x , create_graph = True , retain_graph = True )
203+ (grad_x ** 2 ).sum ().backward ()
204+ self .assertIsNotNone (x .grad , f"Rank { self .rank } : second-order grad missing" )
205+ self .assertTrue (
206+ (x .grad .abs () > 1e-6 ).any (),
207+ f"Rank { self .rank } : second-order grad should be non-zero" ,
208+ )
209+
210+ def test_gradgradcheck_fp64_world_group (self ):
211+ """Gradgradcheck should exercise _AllToAllDouble with WORLD group."""
212+ torch .manual_seed (20260518 )
213+ if self .device .type == "cuda" :
214+ torch .cuda .manual_seed_all (20260518 )
215+
216+ send_splits = [1 ] * self .world_size
217+ recv_splits = [1 ] * self .world_size
218+ x = torch .randn (
219+ self .world_size ,
220+ 2 ,
221+ dtype = torch .float64 ,
222+ device = self .device ,
223+ requires_grad = True ,
224+ )
225+
226+ def func (inp ):
227+ out = all_to_all_differentiable (
228+ inp , send_splits , recv_splits , group = self .group
229+ )
230+ # Pick the row sourced from this rank so per-rank gradgradcheck
231+ # perturbs only the input that can affect the local output.
232+ return out .narrow (0 , self .rank , 1 )
233+
234+ result = torch .autograd .gradgradcheck (
235+ func ,
236+ (x ,),
237+ eps = 1e-6 ,
238+ atol = 1e-4 ,
239+ raise_exception = False ,
240+ )
241+ self .assertTrue (
242+ result ,
243+ f"Rank { self .rank } : distributed gradgradcheck failed" ,
244+ )
245+
201246
202247def run_tests ():
203248 """Run all tests and report results."""
204249 import sys
205250
206- rank , world_size , _ = setup_dist ()
251+ rank , world_size , device = setup_dist ()
207252
208253 # Only rank 0 prints header
209254 if rank == 0 :
@@ -217,18 +262,20 @@ def run_tests():
217262 result = runner .run (suite )
218263
219264 # Synchronize results across ranks (before cleanup)
220- success = torch .tensor ([1 if result .wasSuccessful () else 0 ], dtype = torch .int32 )
265+ success = torch .tensor (
266+ [1 if result .wasSuccessful () else 0 ], dtype = torch .int32 , device = device
267+ )
221268 if dist .is_initialized ():
222269 dist .all_reduce (success , op = dist .ReduceOp .MIN )
223270
224271 if rank == 0 :
225272 if success .item () == 1 :
226273 sys .stdout .write (f"\n { '=' * 70 } \n " )
227- sys .stdout .write (f"✓ All tests passed on all { world_size } ranks\n " )
274+ sys .stdout .write (f"PASS: all tests passed on all { world_size } ranks\n " )
228275 sys .stdout .write (f"{ '=' * 70 } \n \n " )
229276 else :
230277 sys .stdout .write (f"\n { '=' * 70 } \n " )
231- sys .stdout .write ("✗ Tests failed on at least one rank\n " )
278+ sys .stdout .write ("FAIL: tests failed on at least one rank\n " )
232279 sys .stdout .write (f"{ '=' * 70 } \n \n " )
233280
234281 cleanup_dist ()
0 commit comments