Skip to content

Commit 6f17fe4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8e95a19 commit 6f17fe4

7 files changed

Lines changed: 362 additions & 412 deletions

File tree

transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token.cu

Lines changed: 184 additions & 203 deletions
Large diffs are not rendered by default.

transformer_engine/common/cast/nvfp4/quantize_nvfp4_per_token_group.cu

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ constexpr int kMaxTensorsPerKernel = 64;
5757
// ============================================================================
5858
struct NVFP4PerTokenMultiArgs {
5959
// K1 outputs (per-tensor pointers; one fp32 array per tensor)
60-
void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,)
61-
void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,)
60+
void* row_amax_list[kMaxTensorsPerKernel]; // each: float* (M_i,)
61+
void* col_amax_list[kMaxTensorsPerKernel]; // each: float* (K,)
6262

6363
// K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF)
64-
void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2)
65-
void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16)
66-
void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2)
67-
void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16)
64+
void* q_row_list[kMaxTensorsPerKernel]; // each: uint8* (M_i, K/2)
65+
void* s_dec_row_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (M_i, K/16)
66+
void* q_col_list[kMaxTensorsPerKernel]; // each: uint8* (K, M_i/2)
67+
void* s_dec_col_list[kMaxTensorsPerKernel]; // each: fp8e4m3* (K, M_i/16)
6868

6969
// Shared layout info
7070
int split_sections_range[kMaxTensorsPerKernel + 1]; // prefix sum w/ leading 0
@@ -178,13 +178,11 @@ __global__ void __launch_bounds__(kColStripThreads)
178178

179179
int cur_tensor_id = 0;
180180
while (cur_tensor_id < args.num_tensors &&
181-
args.split_sections_range[cur_tensor_id + 1] ==
182-
args.split_sections_range[cur_tensor_id]) {
181+
args.split_sections_range[cur_tensor_id + 1] == args.split_sections_range[cur_tensor_id]) {
183182
++cur_tensor_id;
184183
}
185-
int cur_tensor_end = (cur_tensor_id < args.num_tensors)
186-
? args.split_sections_range[cur_tensor_id + 1]
187-
: 0;
184+
int cur_tensor_end =
185+
(cur_tensor_id < args.num_tensors) ? args.split_sections_range[cur_tensor_id + 1] : 0;
188186

189187
// Walk M in kInnerK-row chunks. split_sections[i] % 64 == 0 implies
190188
// every chunk boundary aligns with a split boundary, so we never
@@ -194,10 +192,8 @@ __global__ void __launch_bounds__(kColStripThreads)
194192
// until cur_tensor_end > m_base. Only flush for NON-EMPTY tensors
195193
// (empty tensors' col_amax_list[] slots are NULL).
196194
while (m_base >= cur_tensor_end) {
197-
if (col_in_range && my_col_amax_acc > 0.f &&
198-
args.col_amax_list[cur_tensor_id] != nullptr) {
199-
float* dst =
200-
reinterpret_cast<float*>(args.col_amax_list[cur_tensor_id]) + my_col;
195+
if (col_in_range && my_col_amax_acc > 0.f && args.col_amax_list[cur_tensor_id] != nullptr) {
196+
float* dst = reinterpret_cast<float*>(args.col_amax_list[cur_tensor_id]) + my_col;
201197
atomicMaxFloat(dst, my_col_amax_acc);
202198
}
203199
my_col_amax_acc = 0.f;
@@ -207,8 +203,8 @@ __global__ void __launch_bounds__(kColStripThreads)
207203
}
208204
if (cur_tensor_id >= args.num_tensors) break;
209205

210-
// Per-element scan within the 16-row chunk (verbatim from single-tensor
211-
// K1 col Pass 1).
206+
// Per-element scan within the 16-row chunk (verbatim from single-tensor
207+
// K1 col Pass 1).
212208
#pragma unroll
213209
for (int e = 0; e < kInnerK; ++e) {
214210
const int gr = m_base + e;
@@ -255,17 +251,14 @@ __global__ void __launch_bounds__(kRowwiseThreads)
255251
const int local_row = global_row - args.split_sections_range[tensor_id];
256252

257253
// Read the row's outer amax (populated by K1-group rowwise).
258-
const float row_amax =
259-
reinterpret_cast<float*>(args.row_amax_list[tensor_id])[local_row];
254+
const float row_amax = reinterpret_cast<float*>(args.row_amax_list[tensor_id])[local_row];
260255
const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(row_amax, 1e-12f));
261256

262257
// Per-tensor row base pointers.
263-
uint8_t* row_out =
264-
reinterpret_cast<uint8_t*>(args.q_row_list[tensor_id]) +
265-
static_cast<size_t>(local_row) * (K / 2);
266-
fp8e4m3* s_dec_out =
267-
reinterpret_cast<fp8e4m3*>(args.s_dec_row_list[tensor_id]) +
268-
static_cast<size_t>(local_row) * (K / kInnerK);
258+
uint8_t* row_out = reinterpret_cast<uint8_t*>(args.q_row_list[tensor_id]) +
259+
static_cast<size_t>(local_row) * (K / 2);
260+
fp8e4m3* s_dec_out = reinterpret_cast<fp8e4m3*>(args.s_dec_row_list[tensor_id]) +
261+
static_cast<size_t>(local_row) * (K / kInnerK);
269262

270263
// === verbatim from single-tensor K1 rowwise Pass 2 ===
271264
const int n_blocks = K / kInnerK;
@@ -275,8 +268,7 @@ __global__ void __launch_bounds__(kRowwiseThreads)
275268
float bmx = 0.f;
276269
#pragma unroll
277270
for (int e = 0; e < kInnerK; e++) {
278-
const float v =
279-
static_cast<float>(in[static_cast<size_t>(global_row) * K + b * kInnerK + e]);
271+
const float v = static_cast<float>(in[static_cast<size_t>(global_row) * K + b * kInnerK + e]);
280272
vals[e] = v;
281273
bmx = fmaxf(bmx, fabsf(v));
282274
}
@@ -341,12 +333,12 @@ __global__ void __launch_bounds__(kColStripThreads)
341333
// Per-tensor cached state. Initialize so the first chunk (b == 0, m_base == 0)
342334
// triggers the boundary-advance to populate these.
343335
float S_enc_cur = 0.f;
344-
int cur_tensor_id = -1; // -1 forces first-iteration advance
345-
int cur_tensor_end = 0; // exclusive
346-
int local_block_base = 0; // global block index of this tensor's first block
336+
int cur_tensor_id = -1; // -1 forces first-iteration advance
337+
int cur_tensor_end = 0; // exclusive
338+
int local_block_base = 0; // global block index of this tensor's first block
347339
uint8_t* col_out = nullptr;
348340
fp8e4m3* s_dec_col_out = nullptr;
349-
int cur_tensor_M = 0; // = split_sections[cur_tensor_id]
341+
int cur_tensor_M = 0; // = split_sections[cur_tensor_id]
350342

351343
const int n_blocks_m = sum_M / kInnerK;
352344
for (int b = 0; b < n_blocks_m; b++) {
@@ -366,8 +358,7 @@ __global__ void __launch_bounds__(kColStripThreads)
366358
need_refresh = true;
367359
}
368360
if (need_refresh && col_in_range && cur_tensor_M > 0) {
369-
const float col_amax =
370-
reinterpret_cast<float*>(args.col_amax_list[cur_tensor_id])[my_col];
361+
const float col_amax = reinterpret_cast<float*>(args.col_amax_list[cur_tensor_id])[my_col];
371362
S_enc_cur = compute_global_encode_scaling_factor_FP4(fmaxf(col_amax, 1e-12f));
372363
col_out = reinterpret_cast<uint8_t*>(args.q_col_list[cur_tensor_id]) +
373364
static_cast<size_t>(my_col) * (cur_tensor_M / 2);
@@ -499,21 +490,20 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector<Tensor*>& outputs,
499490
args->split_sections_range[0] = 0;
500491
for (size_t i = 0; i < num_tensors; ++i) {
501492
Tensor* o = outputs[i];
502-
NVTE_CHECK(split_sections[i] % 64 == 0, "split_sections[", i,
503-
"] = ", split_sections[i], " must be a multiple of 64");
493+
NVTE_CHECK(split_sections[i] % 64 == 0, "split_sections[", i, "] = ", split_sections[i],
494+
" must be a multiple of 64");
504495
args->split_sections_range[i + 1] =
505496
args->split_sections_range[i] + static_cast<int>(split_sections[i]);
506497
// Empty splits skip pointer validation -- the kernel boundary-advance
507498
// loop walks PAST them in zero iterations, never touching the pointer.
508499
if (split_sections[i] == 0) continue;
509500
if (which_buffers & kBufRowAmax) {
510-
NVTE_CHECK(o->amax.dptr != nullptr,
511-
"NVFP4 per-token grouped: outputs[", i, "].amax must be allocated for rowwise");
501+
NVTE_CHECK(o->amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i,
502+
"].amax must be allocated for rowwise");
512503
args->row_amax_list[i] = o->amax.dptr;
513504
}
514505
if (which_buffers & kBufColAmax) {
515-
NVTE_CHECK(o->columnwise_amax.dptr != nullptr,
516-
"NVFP4 per-token grouped: outputs[", i,
506+
NVTE_CHECK(o->columnwise_amax.dptr != nullptr, "NVFP4 per-token grouped: outputs[", i,
517507
"].columnwise_amax must be allocated for columnwise");
518508
args->col_amax_list[i] = o->columnwise_amax.dptr;
519509
}
@@ -525,10 +515,9 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector<Tensor*>& outputs,
525515
args->s_dec_row_list[i] = o->scale_inv.dptr;
526516
}
527517
if (which_buffers & kBufColCast) {
528-
NVTE_CHECK(
529-
o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr,
530-
"NVFP4 per-token grouped: outputs[", i,
531-
"].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast");
518+
NVTE_CHECK(o->columnwise_data.dptr != nullptr && o->columnwise_scale_inv.dptr != nullptr,
519+
"NVFP4 per-token grouped: outputs[", i,
520+
"].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast");
532521
args->q_col_list[i] = o->columnwise_data.dptr;
533522
args->s_dec_col_list[i] = o->columnwise_scale_inv.dptr;
534523
}
@@ -557,8 +546,8 @@ void quantize_per_token_grouped(const Tensor& input, std::vector<Tensor*>& outpu
557546
const int sum_M = static_cast<int>(input.flat_first_dim());
558547
const int K = static_cast<int>(input.flat_last_dim());
559548
if (sum_M == 0 || K == 0) return;
560-
NVTE_CHECK(K % kInnerK == 0,
561-
"NVFP4 per-token grouped: K (", K, ") must be a multiple of ", kInnerK);
549+
NVTE_CHECK(K % kInnerK == 0, "NVFP4 per-token grouped: K (", K, ") must be a multiple of ",
550+
kInnerK);
562551

563552
// Amax buffer pointers must be populated whenever EITHER the K1 (writes
564553
// amax) or K2 (reads amax) pass runs in that direction. K2 reads
@@ -606,16 +595,16 @@ std::vector<transformer_engine::Tensor*> collect_outputs(NVTETensor* outputs, si
606595
} // namespace
607596

608597
void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs,
609-
const size_t* split_sections, size_t num_tensors,
610-
bool rowwise, bool columnwise, cudaStream_t stream) {
598+
const size_t* split_sections, size_t num_tensors, bool rowwise,
599+
bool columnwise, cudaStream_t stream) {
611600
#if FP4_TYPE_SUPPORTED
612601
NVTE_API_CALL(nvte_group_nvfp4_per_token_amax);
613602
using namespace transformer_engine;
614603
if (num_tensors == 0) return;
615604
const Tensor* in = convertNVTETensorCheck(input);
616605
std::vector<Tensor*> outs = collect_outputs(outputs, num_tensors);
617-
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors,
618-
rowwise, columnwise,
606+
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise,
607+
columnwise,
619608
/*do_amax=*/true, /*do_cast=*/false, stream);
620609
#else
621610
(void)input;
@@ -630,16 +619,16 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs
630619
}
631620

632621
void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs,
633-
const size_t* split_sections, size_t num_tensors,
634-
bool rowwise, bool columnwise, cudaStream_t stream) {
622+
const size_t* split_sections, size_t num_tensors, bool rowwise,
623+
bool columnwise, cudaStream_t stream) {
635624
#if FP4_TYPE_SUPPORTED
636625
NVTE_API_CALL(nvte_group_nvfp4_per_token_cast);
637626
using namespace transformer_engine;
638627
if (num_tensors == 0) return;
639628
const Tensor* in = convertNVTETensorCheck(input);
640629
std::vector<Tensor*> outs = collect_outputs(outputs, num_tensors);
641-
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors,
642-
rowwise, columnwise,
630+
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise,
631+
columnwise,
643632
/*do_amax=*/false, /*do_cast=*/true, stream);
644633
#else
645634
(void)input;
@@ -662,8 +651,8 @@ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* out
662651
if (num_tensors == 0) return;
663652
const Tensor* in = convertNVTETensorCheck(input);
664653
std::vector<Tensor*> outs = collect_outputs(outputs, num_tensors);
665-
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors,
666-
rowwise, columnwise,
654+
nvfp4_per_token_group::quantize_per_token_grouped(*in, outs, split_sections, num_tensors, rowwise,
655+
columnwise,
667656
/*do_amax=*/true, /*do_cast=*/true, stream);
668657
#else
669658
(void)input;

transformer_engine/common/include/transformer_engine/nvfp4_per_token.h

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,30 @@
1515
extern "C" {
1616
#endif
1717

18-
1918
/*! \brief Composite K1+K2: per-row + per-col amax (K1) then FP4 + 1x16
2019
* e4m3 SF encode (K2), back-to-back on the same stream.
2120
*
2221
* This is the production entry point for the per-token cast on bf16 +
2322
* 128-aligned shapes.
2423
*/
25-
void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop,
26-
NVTETensor output, cudaStream_t stream);
24+
void nvte_nvfp4_per_token_quantize(const NVTETensor input, const NVTETensor noop, NVTETensor output,
25+
cudaStream_t stream);
2726

2827
/*! \brief Kernel 1 in isolation: per-row + per-col amax via TMA + atomicMax.
2928
* Pre-zeroes the amax buffers and merges per-CTA partials into
3029
* ``output->amax`` (size [M]) / ``output->columnwise_amax``
3130
* (size [K]). Does NOT touch FP4 data / scale_inv slots.
3231
*/
33-
void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop,
34-
NVTETensor output, cudaStream_t stream);
32+
void nvte_nvfp4_per_token_amax(const NVTETensor input, const NVTETensor noop, NVTETensor output,
33+
cudaStream_t stream);
3534

3635
/*! \brief Kernel 2 in isolation: FP4 + 1x16 e4m3 SF encode given a
3736
* pre-filled ``output->amax`` / ``output->columnwise_amax``. Reads
3837
* the outer amax buffer(s) and writes the FP4 data / scale_inv
3938
* tensors only.
4039
*/
41-
void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop,
42-
NVTETensor output, cudaStream_t stream);
40+
void nvte_nvfp4_per_token_encode(const NVTETensor input, const NVTETensor noop, NVTETensor output,
41+
cudaStream_t stream);
4342

4443
/*! \brief Returns 1 iff the per-token kernels accept ``(M, K, dtype)``.
4544
*
@@ -59,8 +58,7 @@ int nvte_nvfp4_per_token_can_dispatch(size_t M, size_t K, int input_dtype_enum);
5958
* d[i, j] = d[i, j] * row_amax_a[i] * row_amax_b[j]
6059
*/
6160
void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a,
62-
const NVTETensor row_amax_b,
63-
cudaStream_t stream);
61+
const NVTETensor row_amax_b, cudaStream_t stream);
6462

6563
/* ============================================================================
6664
* Grouped (multi-tensor) per-token quantize.
@@ -76,8 +74,8 @@ void nvte_nvfp4_per_token_post_scale(NVTETensor d, const NVTETensor row_amax_a,
7674
* \param[in] stream CUDA stream
7775
*/
7876
void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs,
79-
const size_t* split_sections, size_t num_tensors,
80-
bool rowwise, bool columnwise, cudaStream_t stream);
77+
const size_t* split_sections, size_t num_tensors, bool rowwise,
78+
bool columnwise, cudaStream_t stream);
8179

8280
/*! \brief Grouped per-token encode (FP4 + 1x16 e4m3 inner SF) using the
8381
* row_amax / col_amax values already populated by
@@ -94,8 +92,8 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs
9492
* \param[in] stream CUDA stream
9593
*/
9694
void nvte_group_nvfp4_per_token_cast(const NVTETensor input, NVTETensor* outputs,
97-
const size_t* split_sections, size_t num_tensors,
98-
bool rowwise, bool columnwise, cudaStream_t stream);
95+
const size_t* split_sections, size_t num_tensors, bool rowwise,
96+
bool columnwise, cudaStream_t stream);
9997

10098
/*! \brief Composite K1+K2 grouped per-token quantize. Calls the amax + cast
10199
* kernels on the same stream. This is the external API

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -448,30 +448,27 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi
448448
const at::Tensor &scale_inv_colwise, int rows, int cols,
449449
size_t start_offset);
450450

451-
void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row,
452-
at::Tensor s_dec_row, at::Tensor row_amax,
453-
at::Tensor q_col, at::Tensor s_dec_col,
451+
void nvfp4_per_token_quantize(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row,
452+
at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col,
454453
at::Tensor col_amax, bool rowwise, bool columnwise);
455454

456-
void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax,
457-
at::Tensor col_amax, bool rowwise, bool columnwise);
455+
void nvfp4_per_token_amax(const at::Tensor &input, at::Tensor row_amax, at::Tensor col_amax,
456+
bool rowwise, bool columnwise);
458457

459-
void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row,
460-
at::Tensor s_dec_row, at::Tensor row_amax,
461-
at::Tensor q_col, at::Tensor s_dec_col,
462-
at::Tensor col_amax, bool rowwise, bool columnwise);
458+
void nvfp4_per_token_encode(const at::Tensor &input, at::Tensor q_row, at::Tensor s_dec_row,
459+
at::Tensor row_amax, at::Tensor q_col, at::Tensor s_dec_col,
460+
at::Tensor col_amax, bool rowwise, bool columnwise);
463461

464462
void nvfp4_per_token_group_quantize(
465463
const at::Tensor &input, const std::vector<int64_t> &split_sections,
466464
std::vector<at::Tensor> q_row_list, std::vector<at::Tensor> s_dec_row_list,
467465
std::vector<at::Tensor> row_amax_list, std::vector<at::Tensor> q_col_list,
468-
std::vector<at::Tensor> s_dec_col_list, std::vector<at::Tensor> col_amax_list,
469-
bool rowwise, bool columnwise);
466+
std::vector<at::Tensor> s_dec_col_list, std::vector<at::Tensor> col_amax_list, bool rowwise,
467+
bool columnwise);
470468

471469
// Amax-only variant of the grouped quantize. Useful for multi-rank training
472470
// where amax is allReduced before the cast pass.
473-
void nvfp4_per_token_group_amax(const at::Tensor &input,
474-
const std::vector<int64_t> &split_sections,
471+
void nvfp4_per_token_group_amax(const at::Tensor &input, const std::vector<int64_t> &split_sections,
475472
std::vector<at::Tensor> row_amax_list,
476473
std::vector<at::Tensor> col_amax_list, bool rowwise,
477474
bool columnwise);

0 commit comments

Comments
 (0)