Skip to content

Commit fc8a940

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8b1d688 commit fc8a940

14 files changed

Lines changed: 587 additions & 368 deletions

File tree

tests/pytorch/distributed/test_gtp.py

Lines changed: 282 additions & 129 deletions
Large diffs are not rendered by default.

tests/pytorch/distributed/test_tp_gtp.py

Lines changed: 105 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
# Fixtures
4343
# ---------------------------------------------------------------------------
4444

45+
4546
@pytest.fixture(autouse=True)
4647
def reset_fp8_state():
4748
yield
@@ -61,6 +62,7 @@ def reset_gtp_globals():
6162
# Helpers
6263
# ---------------------------------------------------------------------------
6364

65+
6466
def _free_port() -> int:
6567
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
6668
s.bind(("", 0))
@@ -125,18 +127,21 @@ def _build_groups(rank: int, world_size: int, tp_size: int, gtp_size: int):
125127
# 1. TestTPGTPProcessGroups – group sizes and rank membership
126128
# ---------------------------------------------------------------------------
127129

130+
128131
def _worker_groups(rank, world_size, port, tp_size, gtp_size):
129132
_dist_init(rank, world_size, port)
130133
tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size)
131134

132-
assert tp_group.size() == tp_size, \
133-
f"rank {rank}: TP group size {tp_group.size()} != {tp_size}"
134-
assert gtp_group.size() == gtp_size, \
135-
f"rank {rank}: GTP group size {gtp_group.size()} != {gtp_size}"
136-
assert dist.get_rank(tp_group) == tp_rank, \
137-
f"rank {rank}: TP rank {dist.get_rank(tp_group)} != expected {tp_rank}"
138-
assert dist.get_rank(gtp_group) == gtp_rank, \
139-
f"rank {rank}: GTP rank {dist.get_rank(gtp_group)} != expected {gtp_rank}"
135+
assert tp_group.size() == tp_size, f"rank {rank}: TP group size {tp_group.size()} != {tp_size}"
136+
assert (
137+
gtp_group.size() == gtp_size
138+
), f"rank {rank}: GTP group size {gtp_group.size()} != {gtp_size}"
139+
assert (
140+
dist.get_rank(tp_group) == tp_rank
141+
), f"rank {rank}: TP rank {dist.get_rank(tp_group)} != expected {tp_rank}"
142+
assert (
143+
dist.get_rank(gtp_group) == gtp_rank
144+
), f"rank {rank}: GTP rank {dist.get_rank(gtp_group)} != expected {gtp_rank}"
140145

141146
dist.destroy_process_group()
142147

@@ -153,25 +158,34 @@ def test_group_sizes_and_ranks(self, tp_size, gtp_size):
153158
# 2. TestTPGTPColumnParallelLinear
154159
# ---------------------------------------------------------------------------
155160

161+
156162
def _worker_column_shape(rank, world_size, port, tp_size, gtp_size):
157163
"""Column-parallel: weight shape must be [out_f/(tp_size*gtp_size), in_f]."""
158164
_dist_init(rank, world_size, port)
159165
tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size)
160166

161167
in_f = 64
162-
out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
168+
out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
163169

164170
layer = te.Linear(
165-
in_features=in_f, out_features=out_f,
166-
parallel_mode="column", bias=False, params_dtype=torch.bfloat16,
167-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
171+
in_features=in_f,
172+
out_features=out_f,
173+
parallel_mode="column",
174+
bias=False,
175+
params_dtype=torch.bfloat16,
176+
device="cuda",
177+
tp_group=tp_group,
178+
gtp_group=gtp_group,
168179
)
169180

170181
expected_rows = out_f // (tp_size * gtp_size)
171-
assert isinstance(layer.weight, GTPShardedParam), \
172-
f"rank {rank}: weight should be GTPShardedParam"
173-
assert layer.weight.shape == (expected_rows, in_f), \
174-
f"rank {rank}: expected ({expected_rows}, {in_f}), got {layer.weight.shape}"
182+
assert isinstance(
183+
layer.weight, GTPShardedParam
184+
), f"rank {rank}: weight should be GTPShardedParam"
185+
assert layer.weight.shape == (
186+
expected_rows,
187+
in_f,
188+
), f"rank {rank}: expected ({expected_rows}, {in_f}), got {layer.weight.shape}"
175189

176190
dist.destroy_process_group()
177191

@@ -183,21 +197,26 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size):
183197
tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size)
184198

185199
batch, in_f = 16, 64
186-
out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
200+
out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows
187201
dtype = torch.bfloat16
188202

189203
layer = te.Linear(
190-
in_features=in_f, out_features=out_f,
191-
parallel_mode="column", bias=False, params_dtype=dtype,
192-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
204+
in_features=in_f,
205+
out_features=out_f,
206+
parallel_mode="column",
207+
bias=False,
208+
params_dtype=dtype,
209+
device="cuda",
210+
tp_group=tp_group,
211+
gtp_group=gtp_group,
193212
)
194213

195214
# All-gather GTP shards → TP-local full weight [out_f/tp_size, in_f]
196215
shard = layer.weight.data.clone()
197216
all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)]
198217
dist.all_gather(all_gtp_shards, shard, group=gtp_group)
199218
tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # strip padding
200-
tp_local_weight = tp_local_weight[:out_f // tp_size]
219+
tp_local_weight = tp_local_weight[: out_f // tp_size]
201220

202221
# Same full input on all ranks (column-parallel: each rank processes full input)
203222
inp = torch.randn(batch, in_f, dtype=dtype, device="cuda")
@@ -206,16 +225,17 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size):
206225

207226
# TE forward: GTP all-gathers weight internally; no TP comm in column-parallel fwd
208227
out = layer(inp_te, is_first_microbatch=True)
209-
assert out.shape == (batch, out_f // tp_size), \
210-
f"rank {rank}: output shape {out.shape} != ({batch}, {out_f // tp_size})"
228+
assert out.shape == (
229+
batch,
230+
out_f // tp_size,
231+
), f"rank {rank}: output shape {out.shape} != ({batch}, {out_f // tp_size})"
211232

212233
# Reference: this TP rank's output = inp @ tp_local_weight^T
213234
ref = inp.float() @ tp_local_weight.T
214235
ref = ref.to(dtype)
215-
assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), (
216-
f"rank {rank}: output mismatch, "
217-
f"max_diff={(out.float() - ref.float()).abs().max():.4f}"
218-
)
236+
assert torch.allclose(
237+
out.float(), ref.float(), atol=0.1, rtol=0.1
238+
), f"rank {rank}: output mismatch, max_diff={(out.float() - ref.float()).abs().max():.4f}"
219239

220240
# Backward: dX is all-reduced across TP group internally by TE
221241
grad = torch.randn_like(out)
@@ -247,25 +267,33 @@ def test_forward_backward_correctness(self, tp_size, gtp_size):
247267
# 3. TestTPGTPRowParallelLinear
248268
# ---------------------------------------------------------------------------
249269

270+
250271
def _worker_row_shape(rank, world_size, port, tp_size, gtp_size):
251272
"""Row-parallel: weight shape must be [out_f/gtp_size, in_f/tp_size]."""
252273
_dist_init(rank, world_size, port)
253274
tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size)
254275

255-
in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64
276+
in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64
256277
out_f = gtp_size * 64 # GTP divides by gtp_size → local out_f = 64
257278

258279
layer = te.Linear(
259-
in_features=in_f, out_features=out_f,
260-
parallel_mode="row", bias=False, params_dtype=torch.bfloat16,
261-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
280+
in_features=in_f,
281+
out_features=out_f,
282+
parallel_mode="row",
283+
bias=False,
284+
params_dtype=torch.bfloat16,
285+
device="cuda",
286+
tp_group=tp_group,
287+
gtp_group=gtp_group,
262288
)
263289

264290
expected_shape = (out_f // gtp_size, in_f // tp_size)
265-
assert isinstance(layer.weight, GTPShardedParam), \
266-
f"rank {rank}: weight should be GTPShardedParam"
267-
assert layer.weight.shape == expected_shape, \
268-
f"rank {rank}: expected {expected_shape}, got {layer.weight.shape}"
291+
assert isinstance(
292+
layer.weight, GTPShardedParam
293+
), f"rank {rank}: weight should be GTPShardedParam"
294+
assert (
295+
layer.weight.shape == expected_shape
296+
), f"rank {rank}: expected {expected_shape}, got {layer.weight.shape}"
269297

270298
dist.destroy_process_group()
271299

@@ -277,14 +305,19 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size):
277305
tp_group, gtp_group, tp_rank, _ = _build_groups(rank, world_size, tp_size, gtp_size)
278306

279307
batch = 16
280-
in_f = tp_size * 64 # full in_features
308+
in_f = tp_size * 64 # full in_features
281309
out_f = gtp_size * 64 # full out_features
282310
dtype = torch.bfloat16
283311

284312
layer = te.Linear(
285-
in_features=in_f, out_features=out_f,
286-
parallel_mode="row", bias=False, params_dtype=dtype,
287-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
313+
in_features=in_f,
314+
out_features=out_f,
315+
parallel_mode="row",
316+
bias=False,
317+
params_dtype=dtype,
318+
device="cuda",
319+
tp_group=tp_group,
320+
gtp_group=gtp_group,
288321
)
289322

290323
# Row-parallel: each TP rank takes the corresponding slice of in_f
@@ -296,8 +329,10 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size):
296329

297330
# TE forward: GTP all-gathers weight, row-parallel all-reduces output across TP
298331
out = layer(inp, is_first_microbatch=True)
299-
assert out.shape == (batch, out_f), \
300-
f"rank {rank}: output shape {out.shape} != ({batch}, {out_f})"
332+
assert out.shape == (
333+
batch,
334+
out_f,
335+
), f"rank {rank}: output shape {out.shape} != ({batch}, {out_f})"
301336
assert torch.isfinite(out).all(), f"rank {rank}: non-finite output"
302337

303338
# wgrad RS path always accumulates into main_grad; allocate before backward.
@@ -321,20 +356,25 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size):
321356
dtype = torch.bfloat16
322357

323358
layer = te.Linear(
324-
in_features=in_f, out_features=out_f,
325-
parallel_mode="row", bias=False, params_dtype=dtype,
326-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
359+
in_features=in_f,
360+
out_features=out_f,
361+
parallel_mode="row",
362+
bias=False,
363+
params_dtype=dtype,
364+
device="cuda",
365+
tp_group=tp_group,
366+
gtp_group=gtp_group,
327367
)
328368

329369
# Reconstruct full weight: all-gather GTP shards → TP-local, then all-gather TP shards
330370
shard = layer.weight.data.clone()
331371
all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)]
332372
dist.all_gather(all_gtp_shards, shard, group=gtp_group)
333-
tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # [out_f, in_f/tp_size]
373+
tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # [out_f, in_f/tp_size]
334374

335375
all_tp_weights = [torch.zeros_like(tp_local_weight) for _ in range(tp_size)]
336376
dist.all_gather(all_tp_weights, tp_local_weight, group=tp_group)
337-
full_weight = torch.cat(all_tp_weights, dim=1).float() # [out_f, in_f]
377+
full_weight = torch.cat(all_tp_weights, dim=1).float() # [out_f, in_f]
338378

339379
# Full input (same on all ranks; we slice below to simulate row-parallel)
340380
full_inp = torch.randn(batch, in_f, dtype=dtype, device="cuda")
@@ -348,10 +388,9 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size):
348388
# Reference: full input @ full weight^T — all ranks should see the same output
349389
ref = full_inp.float() @ full_weight.T
350390
ref = ref.to(dtype)
351-
assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), (
352-
f"rank {rank}: output mismatch, "
353-
f"max_diff={(out.float() - ref.float()).abs().max():.4f}"
354-
)
391+
assert torch.allclose(
392+
out.float(), ref.float(), atol=0.1, rtol=0.1
393+
), f"rank {rank}: output mismatch, max_diff={(out.float() - ref.float()).abs().max():.4f}"
355394

356395
dist.destroy_process_group()
357396

@@ -380,6 +419,7 @@ def test_forward_correctness(self, tp_size, gtp_size):
380419
# 4. TestTPGTPLayerNormLinear – column-parallel smoke test
381420
# ---------------------------------------------------------------------------
382421

422+
383423
def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size):
384424
_dist_init(rank, world_size, port)
385425
torch.manual_seed(0)
@@ -391,23 +431,29 @@ def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size):
391431
dtype = torch.bfloat16
392432

393433
layer = te.LayerNormLinear(
394-
in_features=in_f, out_features=out_f,
395-
bias=False, params_dtype=dtype,
434+
in_features=in_f,
435+
out_features=out_f,
436+
bias=False,
437+
params_dtype=dtype,
396438
parallel_mode="column",
397-
device="cuda", tp_group=tp_group, gtp_group=gtp_group,
439+
device="cuda",
440+
tp_group=tp_group,
441+
gtp_group=gtp_group,
398442
)
399-
assert isinstance(layer.weight, GTPShardedParam), \
400-
f"rank {rank}: LayerNormLinear.weight should be GTPShardedParam"
443+
assert isinstance(
444+
layer.weight, GTPShardedParam
445+
), f"rank {rank}: LayerNormLinear.weight should be GTPShardedParam"
401446
expected_rows = out_f // (tp_size * gtp_size)
402-
assert layer.weight.shape == (expected_rows, in_f), \
403-
f"rank {rank}: unexpected weight shape {layer.weight.shape}"
447+
assert layer.weight.shape == (
448+
expected_rows,
449+
in_f,
450+
), f"rank {rank}: unexpected weight shape {layer.weight.shape}"
404451

405452
inp = torch.randn(seq, batch, in_f, dtype=dtype, device="cuda", requires_grad=True)
406453
dist.broadcast(inp, src=0)
407454

408455
out = layer(inp, is_first_microbatch=True)
409-
assert out.shape == (seq, batch, out_f // tp_size), \
410-
f"rank {rank}: output shape {out.shape}"
456+
assert out.shape == (seq, batch, out_f // tp_size), f"rank {rank}: output shape {out.shape}"
411457
assert torch.isfinite(out).all(), f"rank {rank}: non-finite output"
412458

413459
# wgrad RS path always accumulates into main_grad; allocate before backward.

transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
116116
* \param[in] config Quantization configuration (for noop_tensor). May be NULL.
117117
* \param[in] stream CUDA stream used for the operation.
118118
*/
119-
void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
119+
void nvte_multi_compute_amax(const NVTETensor* inputs, NVTETensor* outputs, size_t num_tensors,
120120
const NVTEQuantizationConfig config, cudaStream_t stream);
121121

122122
/*! \brief Update an FP8 tensor's scale based on its amax.

transformer_engine/common/recipe/multi_amax.cu

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -81,8 +81,7 @@ __launch_bounds__(multi_amax_kernel_threads) __global__
8181
InputType max = InputType{0.f};
8282
const int warp_id = threadIdx.x / THREADS_PER_WARP;
8383

84-
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M;
85-
tid += gridDim.x * blockDim.x) {
84+
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
8685
loader.load(tid, N);
8786
#pragma unroll
8887
for (int i = 0; i < nvec; ++i) {
@@ -146,18 +145,15 @@ void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignm
146145

147146
switch (align) {
148147
case Alignment::SAME_ALIGNED:
149-
MultiAmaxKernel<nvec, true, InputType>
150-
<<<grid, threads, 0, stream>>>(args, noop_ptr);
148+
MultiAmaxKernel<nvec, true, InputType><<<grid, threads, 0, stream>>>(args, noop_ptr);
151149
break;
152150
case Alignment::SAME_UNALIGNED:
153-
MultiAmaxKernel<nvec, false, InputType>
154-
<<<grid, threads, 0, stream>>>(args, noop_ptr);
151+
MultiAmaxKernel<nvec, false, InputType><<<grid, threads, 0, stream>>>(args, noop_ptr);
155152
break;
156153
case Alignment::DIFFERENT:
157154
// Heterogeneous alignment across tensors — fall back to nvec=1, aligned=true path
158155
// which is safe for any pointer alignment.
159-
MultiAmaxKernel<1, true, InputType>
160-
<<<grid, threads, 0, stream>>>(args, noop_ptr);
156+
MultiAmaxKernel<1, true, InputType><<<grid, threads, 0, stream>>>(args, noop_ptr);
161157
break;
162158
}
163159
NVTE_CHECK_CUDA(cudaGetLastError());
@@ -186,8 +182,8 @@ std::pair<size_t, Alignment> build_batch_args(const std::vector<Tensor *> &input
186182
args.output_rowwise_amax_list[i] = rw_ptr;
187183
args.output_columnwise_amax_list[i] = cw_ptr;
188184
args.input_numel[i] = N;
189-
args.num_aligned_elements[i] = get_num_aligned_elements(inp.data.dptr, N, nvec,
190-
sizeof(InputType));
185+
args.num_aligned_elements[i] =
186+
get_num_aligned_elements(inp.data.dptr, N, nvec, sizeof(InputType));
191187
max_numel = std::max(max_numel, N);
192188

193189
// Fold this tensor's alignment into the batch decision. CheckAlignment on a
@@ -225,11 +221,9 @@ void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, si
225221
outputs[i] = convertNVTETensorCheck(outputs_[i]);
226222
const auto &inp = *inputs[i];
227223
auto &out = *outputs[i];
228-
NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
229-
"nvte_multi_compute_amax: input[", i,
230-
"] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode));
231-
NVTE_CHECK(!is_fp8_dtype(inp.data.dtype),
232-
"nvte_multi_compute_amax: input[", i,
224+
NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "nvte_multi_compute_amax: input[",
225+
i, "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode));
226+
NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), "nvte_multi_compute_amax: input[", i,
233227
"] must be unquantized, got dtype=", to_string(inp.data.dtype));
234228
if (i == 0) {
235229
input_dtype = inp.data.dtype;

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -333,8 +333,7 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
333333
py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer,
334334
const py::object &output);
335335
py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer,
336-
const py::object &output,
337-
std::optional<at::Tensor> noop_flag);
336+
const py::object &output, std::optional<at::Tensor> noop_flag);
338337

339338
// NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate)
340339
// chains into a single pair of kernel launches (one multi-zero + one multi-amax) that

0 commit comments

Comments
 (0)