@@ -57,14 +57,14 @@ constexpr int kMaxTensorsPerKernel = 64;
5757// ============================================================================
5858struct NVFP4PerTokenMultiArgs {
5959 // K1 outputs (per-tensor pointers; one fp32 array per tensor)
60- void * row_amax_list[kMaxTensorsPerKernel ]; // each: float* (M_i,)
61- void * col_amax_list[kMaxTensorsPerKernel ]; // each: float* (K,)
60+ void * row_amax_list[kMaxTensorsPerKernel ]; // each: float* (M_i,)
61+ void * col_amax_list[kMaxTensorsPerKernel ]; // each: float* (K,)
6262
6363 // K2 outputs (per-tensor pointers; FP4 codes + e4m3 inner SF)
64- void * q_row_list[kMaxTensorsPerKernel ]; // each: uint8* (M_i, K/2)
65- void * s_dec_row_list[kMaxTensorsPerKernel ]; // each: fp8e4m3* (M_i, K/16)
66- void * q_col_list[kMaxTensorsPerKernel ]; // each: uint8* (K, M_i/2)
67- void * s_dec_col_list[kMaxTensorsPerKernel ]; // each: fp8e4m3* (K, M_i/16)
64+ void * q_row_list[kMaxTensorsPerKernel ]; // each: uint8* (M_i, K/2)
65+ void * s_dec_row_list[kMaxTensorsPerKernel ]; // each: fp8e4m3* (M_i, K/16)
66+ void * q_col_list[kMaxTensorsPerKernel ]; // each: uint8* (K, M_i/2)
67+ void * s_dec_col_list[kMaxTensorsPerKernel ]; // each: fp8e4m3* (K, M_i/16)
6868
6969 // Shared layout info
7070 int split_sections_range[kMaxTensorsPerKernel + 1 ]; // prefix sum w/ leading 0
@@ -178,13 +178,11 @@ __global__ void __launch_bounds__(kColStripThreads)
178178
179179 int cur_tensor_id = 0 ;
180180 while (cur_tensor_id < args.num_tensors &&
181- args.split_sections_range [cur_tensor_id + 1 ] ==
182- args.split_sections_range [cur_tensor_id]) {
181+ args.split_sections_range [cur_tensor_id + 1 ] == args.split_sections_range [cur_tensor_id]) {
183182 ++cur_tensor_id;
184183 }
185- int cur_tensor_end = (cur_tensor_id < args.num_tensors )
186- ? args.split_sections_range [cur_tensor_id + 1 ]
187- : 0 ;
184+ int cur_tensor_end =
185+ (cur_tensor_id < args.num_tensors ) ? args.split_sections_range [cur_tensor_id + 1 ] : 0 ;
188186
189187 // Walk M in kInnerK-row chunks. split_sections[i] % 64 == 0 implies
190188 // every chunk boundary aligns with a split boundary, so we never
@@ -194,10 +192,8 @@ __global__ void __launch_bounds__(kColStripThreads)
194192 // until cur_tensor_end > m_base. Only flush for NON-EMPTY tensors
195193 // (empty tensors' col_amax_list[] slots are NULL).
196194 while (m_base >= cur_tensor_end) {
197- if (col_in_range && my_col_amax_acc > 0 .f &&
198- args.col_amax_list [cur_tensor_id] != nullptr ) {
199- float * dst =
200- reinterpret_cast <float *>(args.col_amax_list [cur_tensor_id]) + my_col;
195+ if (col_in_range && my_col_amax_acc > 0 .f && args.col_amax_list [cur_tensor_id] != nullptr ) {
196+ float * dst = reinterpret_cast <float *>(args.col_amax_list [cur_tensor_id]) + my_col;
201197 atomicMaxFloat (dst, my_col_amax_acc);
202198 }
203199 my_col_amax_acc = 0 .f ;
@@ -207,8 +203,8 @@ __global__ void __launch_bounds__(kColStripThreads)
207203 }
208204 if (cur_tensor_id >= args.num_tensors ) break ;
209205
210- // Per-element scan within the 16-row chunk (verbatim from single-tensor
211- // K1 col Pass 1).
206+ // Per-element scan within the 16-row chunk (verbatim from single-tensor
207+ // K1 col Pass 1).
212208#pragma unroll
213209 for (int e = 0 ; e < kInnerK ; ++e) {
214210 const int gr = m_base + e;
@@ -255,17 +251,14 @@ __global__ void __launch_bounds__(kRowwiseThreads)
255251 const int local_row = global_row - args.split_sections_range [tensor_id];
256252
257253 // Read the row's outer amax (populated by K1-group rowwise).
258- const float row_amax =
259- reinterpret_cast <float *>(args.row_amax_list [tensor_id])[local_row];
254+ const float row_amax = reinterpret_cast <float *>(args.row_amax_list [tensor_id])[local_row];
260255 const float S_enc = compute_global_encode_scaling_factor_FP4 (fmaxf (row_amax, 1e-12f ));
261256
262257 // Per-tensor row base pointers.
263- uint8_t * row_out =
264- reinterpret_cast <uint8_t *>(args.q_row_list [tensor_id]) +
265- static_cast <size_t >(local_row) * (K / 2 );
266- fp8e4m3* s_dec_out =
267- reinterpret_cast <fp8e4m3*>(args.s_dec_row_list [tensor_id]) +
268- static_cast <size_t >(local_row) * (K / kInnerK );
258+ uint8_t * row_out = reinterpret_cast <uint8_t *>(args.q_row_list [tensor_id]) +
259+ static_cast <size_t >(local_row) * (K / 2 );
260+ fp8e4m3* s_dec_out = reinterpret_cast <fp8e4m3*>(args.s_dec_row_list [tensor_id]) +
261+ static_cast <size_t >(local_row) * (K / kInnerK );
269262
270263 // === verbatim from single-tensor K1 rowwise Pass 2 ===
271264 const int n_blocks = K / kInnerK ;
@@ -275,8 +268,7 @@ __global__ void __launch_bounds__(kRowwiseThreads)
275268 float bmx = 0 .f ;
276269#pragma unroll
277270 for (int e = 0 ; e < kInnerK ; e++) {
278- const float v =
279- static_cast <float >(in[static_cast <size_t >(global_row) * K + b * kInnerK + e]);
271+ const float v = static_cast <float >(in[static_cast <size_t >(global_row) * K + b * kInnerK + e]);
280272 vals[e] = v;
281273 bmx = fmaxf (bmx, fabsf (v));
282274 }
@@ -341,12 +333,12 @@ __global__ void __launch_bounds__(kColStripThreads)
341333 // Per-tensor cached state. Initialize so the first chunk (b == 0, m_base == 0)
342334 // triggers the boundary-advance to populate these.
343335 float S_enc_cur = 0 .f ;
344- int cur_tensor_id = -1 ; // -1 forces first-iteration advance
345- int cur_tensor_end = 0 ; // exclusive
346- int local_block_base = 0 ; // global block index of this tensor's first block
336+ int cur_tensor_id = -1 ; // -1 forces first-iteration advance
337+ int cur_tensor_end = 0 ; // exclusive
338+ int local_block_base = 0 ; // global block index of this tensor's first block
347339 uint8_t * col_out = nullptr ;
348340 fp8e4m3* s_dec_col_out = nullptr ;
349- int cur_tensor_M = 0 ; // = split_sections[cur_tensor_id]
341+ int cur_tensor_M = 0 ; // = split_sections[cur_tensor_id]
350342
351343 const int n_blocks_m = sum_M / kInnerK ;
352344 for (int b = 0 ; b < n_blocks_m; b++) {
@@ -366,8 +358,7 @@ __global__ void __launch_bounds__(kColStripThreads)
366358 need_refresh = true ;
367359 }
368360 if (need_refresh && col_in_range && cur_tensor_M > 0 ) {
369- const float col_amax =
370- reinterpret_cast <float *>(args.col_amax_list [cur_tensor_id])[my_col];
361+ const float col_amax = reinterpret_cast <float *>(args.col_amax_list [cur_tensor_id])[my_col];
371362 S_enc_cur = compute_global_encode_scaling_factor_FP4 (fmaxf (col_amax, 1e-12f ));
372363 col_out = reinterpret_cast <uint8_t *>(args.q_col_list [cur_tensor_id]) +
373364 static_cast <size_t >(my_col) * (cur_tensor_M / 2 );
@@ -499,21 +490,20 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector<Tensor*>& outputs,
499490 args->split_sections_range [0 ] = 0 ;
500491 for (size_t i = 0 ; i < num_tensors; ++i) {
501492 Tensor* o = outputs[i];
502- NVTE_CHECK (split_sections[i] % 64 == 0 , " split_sections[" , i,
503- " ] = " , split_sections[i], " must be a multiple of 64" );
493+ NVTE_CHECK (split_sections[i] % 64 == 0 , " split_sections[" , i, " ] = " , split_sections[i],
494+ " must be a multiple of 64" );
504495 args->split_sections_range [i + 1 ] =
505496 args->split_sections_range [i] + static_cast <int >(split_sections[i]);
506497 // Empty splits skip pointer validation -- the kernel boundary-advance
507498 // loop walks PAST them in zero iterations, never touching the pointer.
508499 if (split_sections[i] == 0 ) continue ;
509500 if (which_buffers & kBufRowAmax ) {
510- NVTE_CHECK (o->amax .dptr != nullptr ,
511- " NVFP4 per-token grouped: outputs[ " , i, " ].amax must be allocated for rowwise" );
501+ NVTE_CHECK (o->amax .dptr != nullptr , " NVFP4 per-token grouped: outputs[ " , i,
502+ " ].amax must be allocated for rowwise" );
512503 args->row_amax_list [i] = o->amax .dptr ;
513504 }
514505 if (which_buffers & kBufColAmax ) {
515- NVTE_CHECK (o->columnwise_amax .dptr != nullptr ,
516- " NVFP4 per-token grouped: outputs[" , i,
506+ NVTE_CHECK (o->columnwise_amax .dptr != nullptr , " NVFP4 per-token grouped: outputs[" , i,
517507 " ].columnwise_amax must be allocated for columnwise" );
518508 args->col_amax_list [i] = o->columnwise_amax .dptr ;
519509 }
@@ -525,10 +515,9 @@ void populate_args(NVFP4PerTokenMultiArgs* args, std::vector<Tensor*>& outputs,
525515 args->s_dec_row_list [i] = o->scale_inv .dptr ;
526516 }
527517 if (which_buffers & kBufColCast ) {
528- NVTE_CHECK (
529- o->columnwise_data .dptr != nullptr && o->columnwise_scale_inv .dptr != nullptr ,
530- " NVFP4 per-token grouped: outputs[" , i,
531- " ].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast" );
518+ NVTE_CHECK (o->columnwise_data .dptr != nullptr && o->columnwise_scale_inv .dptr != nullptr ,
519+ " NVFP4 per-token grouped: outputs[" , i,
520+ " ].columnwise_data + .columnwise_scale_inv must be allocated for columnwise cast" );
532521 args->q_col_list [i] = o->columnwise_data .dptr ;
533522 args->s_dec_col_list [i] = o->columnwise_scale_inv .dptr ;
534523 }
@@ -557,8 +546,8 @@ void quantize_per_token_grouped(const Tensor& input, std::vector<Tensor*>& outpu
557546 const int sum_M = static_cast <int >(input.flat_first_dim ());
558547 const int K = static_cast <int >(input.flat_last_dim ());
559548 if (sum_M == 0 || K == 0 ) return ;
560- NVTE_CHECK (K % kInnerK == 0 ,
561- " NVFP4 per-token grouped: K ( " , K, " ) must be a multiple of " , kInnerK );
549+ NVTE_CHECK (K % kInnerK == 0 , " NVFP4 per-token grouped: K ( " , K, " ) must be a multiple of " ,
550+ kInnerK );
562551
563552 // Amax buffer pointers must be populated whenever EITHER the K1 (writes
564553 // amax) or K2 (reads amax) pass runs in that direction. K2 reads
@@ -606,16 +595,16 @@ std::vector<transformer_engine::Tensor*> collect_outputs(NVTETensor* outputs, si
606595} // namespace
607596
608597void nvte_group_nvfp4_per_token_amax (const NVTETensor input, NVTETensor* outputs,
609- const size_t * split_sections, size_t num_tensors,
610- bool rowwise, bool columnwise, cudaStream_t stream) {
598+ const size_t * split_sections, size_t num_tensors, bool rowwise,
599+ bool columnwise, cudaStream_t stream) {
611600#if FP4_TYPE_SUPPORTED
612601 NVTE_API_CALL (nvte_group_nvfp4_per_token_amax);
613602 using namespace transformer_engine ;
614603 if (num_tensors == 0 ) return ;
615604 const Tensor* in = convertNVTETensorCheck (input);
616605 std::vector<Tensor*> outs = collect_outputs (outputs, num_tensors);
617- nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors,
618- rowwise, columnwise,
606+ nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors, rowwise,
607+ columnwise,
619608 /* do_amax=*/ true , /* do_cast=*/ false , stream);
620609#else
621610 (void )input;
@@ -630,16 +619,16 @@ void nvte_group_nvfp4_per_token_amax(const NVTETensor input, NVTETensor* outputs
630619}
631620
632621void nvte_group_nvfp4_per_token_cast (const NVTETensor input, NVTETensor* outputs,
633- const size_t * split_sections, size_t num_tensors,
634- bool rowwise, bool columnwise, cudaStream_t stream) {
622+ const size_t * split_sections, size_t num_tensors, bool rowwise,
623+ bool columnwise, cudaStream_t stream) {
635624#if FP4_TYPE_SUPPORTED
636625 NVTE_API_CALL (nvte_group_nvfp4_per_token_cast);
637626 using namespace transformer_engine ;
638627 if (num_tensors == 0 ) return ;
639628 const Tensor* in = convertNVTETensorCheck (input);
640629 std::vector<Tensor*> outs = collect_outputs (outputs, num_tensors);
641- nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors,
642- rowwise, columnwise,
630+ nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors, rowwise,
631+ columnwise,
643632 /* do_amax=*/ false , /* do_cast=*/ true , stream);
644633#else
645634 (void )input;
@@ -662,8 +651,8 @@ void nvte_group_nvfp4_per_token_quantize(const NVTETensor input, NVTETensor* out
662651 if (num_tensors == 0 ) return ;
663652 const Tensor* in = convertNVTETensorCheck (input);
664653 std::vector<Tensor*> outs = collect_outputs (outputs, num_tensors);
665- nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors,
666- rowwise, columnwise,
654+ nvfp4_per_token_group::quantize_per_token_grouped (*in, outs, split_sections, num_tensors, rowwise,
655+ columnwise,
667656 /* do_amax=*/ true , /* do_cast=*/ true , stream);
668657#else
669658 (void )input;
0 commit comments