4242# Fixtures
4343# ---------------------------------------------------------------------------
4444
45+
4546@pytest .fixture (autouse = True )
4647def reset_fp8_state ():
4748 yield
@@ -61,6 +62,7 @@ def reset_gtp_globals():
6162# Helpers
6263# ---------------------------------------------------------------------------
6364
65+
6466def _free_port () -> int :
6567 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
6668 s .bind (("" , 0 ))
@@ -125,18 +127,21 @@ def _build_groups(rank: int, world_size: int, tp_size: int, gtp_size: int):
125127# 1. TestTPGTPProcessGroups – group sizes and rank membership
126128# ---------------------------------------------------------------------------
127129
130+
128131def _worker_groups (rank , world_size , port , tp_size , gtp_size ):
129132 _dist_init (rank , world_size , port )
130133 tp_group , gtp_group , tp_rank , gtp_rank = _build_groups (rank , world_size , tp_size , gtp_size )
131134
132- assert tp_group .size () == tp_size , \
133- f"rank { rank } : TP group size { tp_group .size ()} != { tp_size } "
134- assert gtp_group .size () == gtp_size , \
135- f"rank { rank } : GTP group size { gtp_group .size ()} != { gtp_size } "
136- assert dist .get_rank (tp_group ) == tp_rank , \
137- f"rank { rank } : TP rank { dist .get_rank (tp_group )} != expected { tp_rank } "
138- assert dist .get_rank (gtp_group ) == gtp_rank , \
139- f"rank { rank } : GTP rank { dist .get_rank (gtp_group )} != expected { gtp_rank } "
135+ assert tp_group .size () == tp_size , f"rank { rank } : TP group size { tp_group .size ()} != { tp_size } "
136+ assert (
137+ gtp_group .size () == gtp_size
138+ ), f"rank { rank } : GTP group size { gtp_group .size ()} != { gtp_size } "
139+ assert (
140+ dist .get_rank (tp_group ) == tp_rank
141+ ), f"rank { rank } : TP rank { dist .get_rank (tp_group )} != expected { tp_rank } "
142+ assert (
143+ dist .get_rank (gtp_group ) == gtp_rank
144+ ), f"rank { rank } : GTP rank { dist .get_rank (gtp_group )} != expected { gtp_rank } "
140145
141146 dist .destroy_process_group ()
142147
@@ -153,25 +158,34 @@ def test_group_sizes_and_ranks(self, tp_size, gtp_size):
153158# 2. TestTPGTPColumnParallelLinear
154159# ---------------------------------------------------------------------------
155160
161+
156162def _worker_column_shape (rank , world_size , port , tp_size , gtp_size ):
157163 """Column-parallel: weight shape must be [out_f/(tp_size*gtp_size), in_f]."""
158164 _dist_init (rank , world_size , port )
159165 tp_group , gtp_group , _ , _ = _build_groups (rank , world_size , tp_size , gtp_size )
160166
161167 in_f = 64
162- out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
168+ out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
163169
164170 layer = te .Linear (
165- in_features = in_f , out_features = out_f ,
166- parallel_mode = "column" , bias = False , params_dtype = torch .bfloat16 ,
167- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
171+ in_features = in_f ,
172+ out_features = out_f ,
173+ parallel_mode = "column" ,
174+ bias = False ,
175+ params_dtype = torch .bfloat16 ,
176+ device = "cuda" ,
177+ tp_group = tp_group ,
178+ gtp_group = gtp_group ,
168179 )
169180
170181 expected_rows = out_f // (tp_size * gtp_size )
171- assert isinstance (layer .weight , GTPShardedParam ), \
172- f"rank { rank } : weight should be GTPShardedParam"
173- assert layer .weight .shape == (expected_rows , in_f ), \
174- f"rank { rank } : expected ({ expected_rows } , { in_f } ), got { layer .weight .shape } "
182+ assert isinstance (
183+ layer .weight , GTPShardedParam
184+ ), f"rank { rank } : weight should be GTPShardedParam"
185+ assert layer .weight .shape == (
186+ expected_rows ,
187+ in_f ,
188+ ), f"rank { rank } : expected ({ expected_rows } , { in_f } ), got { layer .weight .shape } "
175189
176190 dist .destroy_process_group ()
177191
@@ -183,21 +197,26 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size):
183197 tp_group , gtp_group , tp_rank , gtp_rank = _build_groups (rank , world_size , tp_size , gtp_size )
184198
185199 batch , in_f = 16 , 64
186- out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
200+ out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
187201 dtype = torch .bfloat16
188202
189203 layer = te .Linear (
190- in_features = in_f , out_features = out_f ,
191- parallel_mode = "column" , bias = False , params_dtype = dtype ,
192- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
204+ in_features = in_f ,
205+ out_features = out_f ,
206+ parallel_mode = "column" ,
207+ bias = False ,
208+ params_dtype = dtype ,
209+ device = "cuda" ,
210+ tp_group = tp_group ,
211+ gtp_group = gtp_group ,
193212 )
194213
195214 # All-gather GTP shards → TP-local full weight [out_f/tp_size, in_f]
196215 shard = layer .weight .data .clone ()
197216 all_gtp_shards = [torch .zeros_like (shard ) for _ in range (gtp_size )]
198217 dist .all_gather (all_gtp_shards , shard , group = gtp_group )
199218 tp_local_weight = torch .cat (all_gtp_shards , dim = 0 ).float () # strip padding
200- tp_local_weight = tp_local_weight [:out_f // tp_size ]
219+ tp_local_weight = tp_local_weight [: out_f // tp_size ]
201220
202221 # Same full input on all ranks (column-parallel: each rank processes full input)
203222 inp = torch .randn (batch , in_f , dtype = dtype , device = "cuda" )
@@ -206,16 +225,17 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size):
206225
207226 # TE forward: GTP all-gathers weight internally; no TP comm in column-parallel fwd
208227 out = layer (inp_te , is_first_microbatch = True )
209- assert out .shape == (batch , out_f // tp_size ), \
210- f"rank { rank } : output shape { out .shape } != ({ batch } , { out_f // tp_size } )"
228+ assert out .shape == (
229+ batch ,
230+ out_f // tp_size ,
231+ ), f"rank { rank } : output shape { out .shape } != ({ batch } , { out_f // tp_size } )"
211232
212233 # Reference: this TP rank's output = inp @ tp_local_weight^T
213234 ref = inp .float () @ tp_local_weight .T
214235 ref = ref .to (dtype )
215- assert torch .allclose (out .float (), ref .float (), atol = 0.1 , rtol = 0.1 ), (
216- f"rank { rank } : output mismatch, "
217- f"max_diff={ (out .float () - ref .float ()).abs ().max ():.4f} "
218- )
236+ assert torch .allclose (
237+ out .float (), ref .float (), atol = 0.1 , rtol = 0.1
238+ ), f"rank { rank } : output mismatch, max_diff={ (out .float () - ref .float ()).abs ().max ():.4f} "
219239
220240 # Backward: dX is all-reduced across TP group internally by TE
221241 grad = torch .randn_like (out )
@@ -247,25 +267,33 @@ def test_forward_backward_correctness(self, tp_size, gtp_size):
247267# 3. TestTPGTPRowParallelLinear
248268# ---------------------------------------------------------------------------
249269
270+
250271def _worker_row_shape (rank , world_size , port , tp_size , gtp_size ):
251272 """Row-parallel: weight shape must be [out_f/gtp_size, in_f/tp_size]."""
252273 _dist_init (rank , world_size , port )
253274 tp_group , gtp_group , _ , _ = _build_groups (rank , world_size , tp_size , gtp_size )
254275
255- in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64
276+ in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64
256277 out_f = gtp_size * 64 # GTP divides by gtp_size → local out_f = 64
257278
258279 layer = te .Linear (
259- in_features = in_f , out_features = out_f ,
260- parallel_mode = "row" , bias = False , params_dtype = torch .bfloat16 ,
261- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
280+ in_features = in_f ,
281+ out_features = out_f ,
282+ parallel_mode = "row" ,
283+ bias = False ,
284+ params_dtype = torch .bfloat16 ,
285+ device = "cuda" ,
286+ tp_group = tp_group ,
287+ gtp_group = gtp_group ,
262288 )
263289
264290 expected_shape = (out_f // gtp_size , in_f // tp_size )
265- assert isinstance (layer .weight , GTPShardedParam ), \
266- f"rank { rank } : weight should be GTPShardedParam"
267- assert layer .weight .shape == expected_shape , \
268- f"rank { rank } : expected { expected_shape } , got { layer .weight .shape } "
291+ assert isinstance (
292+ layer .weight , GTPShardedParam
293+ ), f"rank { rank } : weight should be GTPShardedParam"
294+ assert (
295+ layer .weight .shape == expected_shape
296+ ), f"rank { rank } : expected { expected_shape } , got { layer .weight .shape } "
269297
270298 dist .destroy_process_group ()
271299
@@ -277,14 +305,19 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size):
277305 tp_group , gtp_group , tp_rank , _ = _build_groups (rank , world_size , tp_size , gtp_size )
278306
279307 batch = 16
280- in_f = tp_size * 64 # full in_features
308+ in_f = tp_size * 64 # full in_features
281309 out_f = gtp_size * 64 # full out_features
282310 dtype = torch .bfloat16
283311
284312 layer = te .Linear (
285- in_features = in_f , out_features = out_f ,
286- parallel_mode = "row" , bias = False , params_dtype = dtype ,
287- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
313+ in_features = in_f ,
314+ out_features = out_f ,
315+ parallel_mode = "row" ,
316+ bias = False ,
317+ params_dtype = dtype ,
318+ device = "cuda" ,
319+ tp_group = tp_group ,
320+ gtp_group = gtp_group ,
288321 )
289322
290323 # Row-parallel: each TP rank takes the corresponding slice of in_f
@@ -296,8 +329,10 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size):
296329
297330 # TE forward: GTP all-gathers weight, row-parallel all-reduces output across TP
298331 out = layer (inp , is_first_microbatch = True )
299- assert out .shape == (batch , out_f ), \
300- f"rank { rank } : output shape { out .shape } != ({ batch } , { out_f } )"
332+ assert out .shape == (
333+ batch ,
334+ out_f ,
335+ ), f"rank { rank } : output shape { out .shape } != ({ batch } , { out_f } )"
301336 assert torch .isfinite (out ).all (), f"rank { rank } : non-finite output"
302337
303338 # wgrad RS path always accumulates into main_grad; allocate before backward.
@@ -321,20 +356,25 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size):
321356 dtype = torch .bfloat16
322357
323358 layer = te .Linear (
324- in_features = in_f , out_features = out_f ,
325- parallel_mode = "row" , bias = False , params_dtype = dtype ,
326- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
359+ in_features = in_f ,
360+ out_features = out_f ,
361+ parallel_mode = "row" ,
362+ bias = False ,
363+ params_dtype = dtype ,
364+ device = "cuda" ,
365+ tp_group = tp_group ,
366+ gtp_group = gtp_group ,
327367 )
328368
329369 # Reconstruct full weight: all-gather GTP shards → TP-local, then all-gather TP shards
330370 shard = layer .weight .data .clone ()
331371 all_gtp_shards = [torch .zeros_like (shard ) for _ in range (gtp_size )]
332372 dist .all_gather (all_gtp_shards , shard , group = gtp_group )
333- tp_local_weight = torch .cat (all_gtp_shards , dim = 0 ).float () # [out_f, in_f/tp_size]
373+ tp_local_weight = torch .cat (all_gtp_shards , dim = 0 ).float () # [out_f, in_f/tp_size]
334374
335375 all_tp_weights = [torch .zeros_like (tp_local_weight ) for _ in range (tp_size )]
336376 dist .all_gather (all_tp_weights , tp_local_weight , group = tp_group )
337- full_weight = torch .cat (all_tp_weights , dim = 1 ).float () # [out_f, in_f]
377+ full_weight = torch .cat (all_tp_weights , dim = 1 ).float () # [out_f, in_f]
338378
339379 # Full input (same on all ranks; we slice below to simulate row-parallel)
340380 full_inp = torch .randn (batch , in_f , dtype = dtype , device = "cuda" )
@@ -348,10 +388,9 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size):
348388 # Reference: full input @ full weight^T — all ranks should see the same output
349389 ref = full_inp .float () @ full_weight .T
350390 ref = ref .to (dtype )
351- assert torch .allclose (out .float (), ref .float (), atol = 0.1 , rtol = 0.1 ), (
352- f"rank { rank } : output mismatch, "
353- f"max_diff={ (out .float () - ref .float ()).abs ().max ():.4f} "
354- )
391+ assert torch .allclose (
392+ out .float (), ref .float (), atol = 0.1 , rtol = 0.1
393+ ), f"rank { rank } : output mismatch, max_diff={ (out .float () - ref .float ()).abs ().max ():.4f} "
355394
356395 dist .destroy_process_group ()
357396
@@ -380,6 +419,7 @@ def test_forward_correctness(self, tp_size, gtp_size):
380419# 4. TestTPGTPLayerNormLinear – column-parallel smoke test
381420# ---------------------------------------------------------------------------
382421
422+
383423def _worker_layernorm_linear (rank , world_size , port , tp_size , gtp_size ):
384424 _dist_init (rank , world_size , port )
385425 torch .manual_seed (0 )
@@ -391,23 +431,29 @@ def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size):
391431 dtype = torch .bfloat16
392432
393433 layer = te .LayerNormLinear (
394- in_features = in_f , out_features = out_f ,
395- bias = False , params_dtype = dtype ,
434+ in_features = in_f ,
435+ out_features = out_f ,
436+ bias = False ,
437+ params_dtype = dtype ,
396438 parallel_mode = "column" ,
397- device = "cuda" , tp_group = tp_group , gtp_group = gtp_group ,
439+ device = "cuda" ,
440+ tp_group = tp_group ,
441+ gtp_group = gtp_group ,
398442 )
399- assert isinstance (layer .weight , GTPShardedParam ), \
400- f"rank { rank } : LayerNormLinear.weight should be GTPShardedParam"
443+ assert isinstance (
444+ layer .weight , GTPShardedParam
445+ ), f"rank { rank } : LayerNormLinear.weight should be GTPShardedParam"
401446 expected_rows = out_f // (tp_size * gtp_size )
402- assert layer .weight .shape == (expected_rows , in_f ), \
403- f"rank { rank } : unexpected weight shape { layer .weight .shape } "
447+ assert layer .weight .shape == (
448+ expected_rows ,
449+ in_f ,
450+ ), f"rank { rank } : unexpected weight shape { layer .weight .shape } "
404451
405452 inp = torch .randn (seq , batch , in_f , dtype = dtype , device = "cuda" , requires_grad = True )
406453 dist .broadcast (inp , src = 0 )
407454
408455 out = layer (inp , is_first_microbatch = True )
409- assert out .shape == (seq , batch , out_f // tp_size ), \
410- f"rank { rank } : output shape { out .shape } "
456+ assert out .shape == (seq , batch , out_f // tp_size ), f"rank { rank } : output shape { out .shape } "
411457 assert torch .isfinite (out ).all (), f"rank { rank } : non-finite output"
412458
413459 # wgrad RS path always accumulates into main_grad; allocate before backward.
0 commit comments