@@ -482,267 +482,3 @@ extern "C" void cgemm_nvfp4_bf16(
482482
483483 launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
484484}

// ============================================================================
// Grouped NVFP4 GEMM for MoE inference
//
// Fuses all expert GEMMs into a single kernel launch. Each threadblock
// handles one (m_tile, n_tile) for one expert, determined by linear
// blockIdx.x decomposition over a precomputed cumulative m-tile table.
//
// A_concat:       [total_tokens, K/2]    -- all expert activations concatenated
// B_all:          [num_experts * N, K/2] -- per-expert weights stacked
// SFA_concat:     [total_tokens, K/16]   -- per-token activation scales
// SFB_all:        [num_experts * N, K/16]-- per-expert weight scales stacked
// D_concat:       [total_tokens, N]      -- output (pre-allocated)
// expert_offsets: [num_experts+1]        -- cumulative token offsets (int32)
// cumul_m_tiles:  [num_experts+1]        -- cumulative m-tile counts (int32)
//
// No split-K: expert parallelism provides sufficient tile count.
// CUDA-graph-safe: no dynamic allocations or cudaMemset.
// ============================================================================

// Fused grouped GEMM kernel: one launch covers every expert's GEMM.
// Each threadblock computes one BLOCK_M_DIM x BLOCK_N_DIM output tile for
// exactly one expert, recovered from blockIdx.x by binary search over the
// cumulative m-tile table. Same shared-memory MMA pipeline as the
// single-GEMM kernel but with no split-K, so accumulators are stored
// directly to D (no workspace, no atomics, no memset) — which is what keeps
// this CUDA-graph-safe.
//
// NOTE(review): BLOCK_M_DIM/BLOCK_N_DIM/N_WARPS/N_TILES_PER_WARP/SMEM_* and
// the helpers swizzled_scale_offset / mma_nvfp4_m16n8k64 / float_to_out are
// defined elsewhere in this file; the smem row stride of 32 bytes below
// assumes a 64-element K-step of packed FP4 (2 values/byte).
template <typename OutT>
// __launch_bounds__: blockDim is WARPS_PER_BLOCK warps; cap register use so
// at least 4 blocks can be resident per SM.
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGroupedGemmNVFP4_smem(
    const unsigned char* __restrict__ A_concat,    // packed FP4 activations, [total_tokens, K/2] bytes
    const unsigned char* __restrict__ B_all,       // packed FP4 weights, [num_experts * N, K/2] bytes
    const unsigned char* __restrict__ SFA_concat,  // activation scales, swizzled layout
    const unsigned char* __restrict__ SFB_all,     // weight scales, swizzled layout
    OutT* __restrict__ D_concat,                   // output, [total_tokens, N]
    const int* __restrict__ expert_offsets,        // [num_experts+1] cumulative token offsets
    const int* __restrict__ cumul_m_tiles,         // [num_experts+1] cumulative m-tile counts
    int N, int K, int num_experts
) {
    // Decompose the linear block index into (global m-tile, n-tile).
    int tile_idx = blockIdx.x;
    int num_n_tiles = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
    int m_tile_global = tile_idx / num_n_tiles;
    int n_tile = tile_idx % num_n_tiles;

    // Binary search for expert owning this m_tile_global: the smallest
    // `expert` with cumul_m_tiles[expert + 1] > m_tile_global.
    int lo = 0, hi = num_experts;
    while (lo < hi) {
        int mid = (lo + hi) / 2;
        if (cumul_m_tiles[mid + 1] <= m_tile_global)
            lo = mid + 1;
        else
            hi = mid;
    }
    int expert = lo;
    // The early returns here and below depend only on blockIdx-derived
    // values, so they are uniform across the block — safe with the
    // __syncthreads() calls later in the kernel.
    if (expert >= num_experts) return;

    int local_m_tile = m_tile_global - cumul_m_tiles[expert];
    int expert_M = expert_offsets[expert + 1] - expert_offsets[expert];
    if (expert_M <= 0) return;  // expert received no tokens

    int row_offset = expert_offsets[expert];
    int half_K = K / 2;    // bytes per row of packed FP4 (2 values per byte)
    int scale_K = K / 16;  // one scale byte per 16 values
    int scale_n_col_blocks = (scale_K + 3) / 4;

    // Point to this expert's data (packed FP4 uses flat per-expert offsets,
    // scales use swizzled layout with absolute row indices — see do_load).
    const unsigned char* A = A_concat + (size_t)row_offset * half_K;
    const unsigned char* B = B_all + (size_t)expert * N * half_K;
    OutT* D = D_concat + (size_t)row_offset * N;
    int M = expert_M;

    // --- Standard tile GEMM (same logic as kGemmNVFP4_smem, no split-K) ---
    // One static shared buffer, partitioned into A/B data and their scales.
    __shared__ __align__(16) unsigned char smem[SMEM_TOTAL];
    unsigned char* smem_A = smem;
    unsigned char* smem_B = smem + SMEM_A_BYTES;
    unsigned char* smem_SFA = smem + SMEM_A_BYTES + SMEM_B_BYTES;
    unsigned char* smem_SFB = smem + SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES;

    const int tid = threadIdx.x;
    const int warp_in_block = tid / 32;
    const int lane_id = tid % 32;
    // Warps tile the block's output: rows in groups of 16 (m_warp), columns
    // in groups of N_TILES_PER_WARP * 8 (n_warp).
    const int m_warp = warp_in_block / N_WARPS;
    const int n_warp = warp_in_block % N_WARPS;

    const int block_m = local_m_tile * BLOCK_M_DIM;
    const int block_n = n_tile * BLOCK_N_DIM;
    const int tile_m = block_m + m_warp * 16;
    const int warp_n_base = block_n + n_warp * N_TILES_PER_WARP * 8;

    // Lane decomposition used by the MMA fragment addressing below.
    const int t0 = lane_id % 4;
    const int t1 = lane_id / 4;

    // Per-lane accumulators: 4 floats per 16x8 output tile.
    float acc[N_TILES_PER_WARP][4];
#pragma unroll
    for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
        acc[nt][0] = acc[nt][1] = acc[nt][2] = acc[nt][3] = 0.0f;
    }

    // Shared-memory rows this lane reads for its A fragment and A scale.
    // The sfa_local_row shuffle matches the scale-factor lane mapping the
    // MMA expects (CUTE-style m-index permutation).
    const int a_local_row0 = m_warp * 16 + 2 * t1;
    const int a_local_row1 = a_local_row0 + 1;
    const int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
    const int cute_sf_m0 = sf_tidx % 16;
    const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8) * 2 + cute_sf_m0 / 8;

    // Global->shared copy assignment: per K-step each row needs 32 bytes of
    // packed data; each thread copies 4 bytes of A and 16 bytes of B.
    const int a_off = tid * 4;
    const int a_load_row = a_off >> 5;   // a_off / 32
    const int a_load_col = a_off & 31;   // a_off % 32
    const int a_gm = block_m + a_load_row;

    const int b_off = tid * 16;
    const int b_load_row = b_off >> 5;
    const int b_load_col = b_off & 31;
    const int b_gn = block_n + b_load_row;

    const bool a_gm_ok = (a_gm < M);
    const bool b_gn_ok = (b_gn < N);
    const int a_row_base = a_gm * half_K;
    const int b_row_base = b_gn * half_K;

    // Pipeline registers: the NEXT K-step is staged here while the current
    // one is consumed from shared memory.
    uint32_t pipe_a = 0;
    uint4 pipe_b = make_uint4(0, 0, 0, 0);
    uint32_t pipe_sfa = 0, pipe_sfb = 0;

    // --- Pipelined K-loop with inlined load/store/compute ---

    // Stage one K-step (k_byte = byte offset into packed rows, k_scale =
    // byte offset into scale rows) into registers. Out-of-range rows and
    // K-tail bytes are masked to zero.
    auto do_load = [&](int k_byte, int k_scale) {
        pipe_a = 0;
        if (a_gm_ok) {
            int ga = a_row_base + k_byte + a_load_col;
            if (k_byte + a_load_col + 3 < half_K)
                pipe_a = *(const uint32_t*)(A + ga);  // fully in-bounds vector load
            else
                for (int i = 0; i < 4; i++)           // K tail: byte-wise masked load
                    if (k_byte + a_load_col + i < half_K)
                        pipe_a |= ((uint32_t)A[ga + i]) << (i * 8);
        }
        if (b_gn_ok) {
            int gb = b_row_base + k_byte + b_load_col;
            if (k_byte + b_load_col + 15 < half_K) {
                uint4 bv = *(const uint4*)(B + gb);
                pipe_b.x = bv.x; pipe_b.y = bv.y; pipe_b.z = bv.z; pipe_b.w = bv.w;
            } else {
                unsigned char buf[16] = {};           // zero-padded K tail
                for (int i = 0; i < 16; i++)
                    if (k_byte + b_load_col + i < half_K) buf[i] = B[gb + i];
                pipe_b = *(uint4*)buf;
            }
        } else { pipe_b = make_uint4(0, 0, 0, 0); }

        // Scales: the first BLOCK_M_DIM threads fetch one row of A scales,
        // the first BLOCK_N_DIM threads one row of B scales. Note the
        // ABSOLUTE row indices (row_offset + gm, expert * N + gn): the
        // swizzled scale tensors are indexed over the full concatenation,
        // unlike the flat per-expert A/B pointers above.
        pipe_sfa = 0;
        if (tid < BLOCK_M_DIM) {
            int gm = block_m + tid;
            if (gm < M) {
                int bs = swizzled_scale_offset(row_offset + gm, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfa = *(const uint32_t*)(SFA_concat + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfa |= ((uint32_t)SFA_concat[bs + i]) << (i * 8);
            }
        }
        pipe_sfb = 0;
        if (tid < BLOCK_N_DIM) {
            int gn = block_n + tid;
            if (gn < N) {
                int bs = swizzled_scale_offset(expert * N + gn, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfb = *(const uint32_t*)(SFB_all + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfb |= ((uint32_t)SFB_all[bs + i]) << (i * 8);
            }
        }
    };

    // Commit the staged registers to shared memory. Callers are responsible
    // for the surrounding __syncthreads() barriers.
    auto do_store = [&]() {
        *(uint32_t*)(smem_A + a_off) = pipe_a;
        *(uint4*)(smem_B + b_off) = pipe_b;
        if (tid < BLOCK_M_DIM) *(uint32_t*)(smem_SFA + tid * 4) = pipe_sfa;
        if (tid < BLOCK_N_DIM) *(uint32_t*)(smem_SFB + tid * 4) = pipe_sfb;
    };

    // Consume the current K-step from shared memory: load the A fragment
    // once, then issue one block-scaled MMA per 8-wide N tile this warp owns.
    auto do_compute = [&]() {
        uint32_t ar[4];
        ar[0] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4);
        ar[1] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4);
        ar[2] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4 + 16);
        ar[3] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4 + 16);
        uint32_t sf = *(const uint32_t*)(smem_SFA + sfa_local_row * 4);
#pragma unroll
        for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
            int ln = n_warp * N_TILES_PER_WARP * 8 + nt * 8;
            int br = ln + t1;
            uint32_t b0 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4);
            uint32_t b1 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4 + 16);
            uint32_t sb = *(const uint32_t*)(smem_SFB + (ln + t1) * 4);
            mma_nvfp4_m16n8k64(
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3],
                ar[0], ar[1], ar[2], ar[3], b0, b1,
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], sf, sb
            );
        }
    };

    // Prologue: load and commit K-step 0.
    do_load(0, 0);
    do_store();
    __syncthreads();

    // Steady state: overlap the global load of step k+64 with the MMA work
    // on step k. The first barrier keeps do_store from overwriting smem
    // while it is still being read; the second publishes the new step.
    for (int k_start = 0; k_start < K; k_start += 64) {
        bool has_next = (k_start + 64 < K);
        if (has_next) do_load((k_start + 64) / 2, (k_start + 64) / 16);
        do_compute();
        __syncthreads();
        if (has_next) { do_store(); __syncthreads(); }
    }

    // Write output (no split-K, direct store). Each lane holds a 2x2 patch
    // per N tile; the row/column mapping matches the accumulator fragment
    // layout produced by mma_nvfp4_m16n8k64.
    int octet = lane_id / 4;
    int quad = lane_id % 4;
    int out_row0 = tile_m + octet * 2;
    int out_row1 = out_row0 + 1;
    int out_col_base = quad * 2;

#pragma unroll
    for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
        int this_tile_n = warp_n_base + nt * 8;
        int c0 = this_tile_n + out_col_base;
        int c1 = c0 + 1;
        // Bounds checks handle partial edge tiles in both M and N.
        if (out_row0 < M && c0 < N) D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]);
        if (out_row0 < M && c1 < N) D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]);
        if (out_row1 < M && c0 < N) D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]);
        if (out_row1 < M && c1 < N) D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]);
    }
}

// ============================================================================
// Grouped GEMM launchers
// ============================================================================

// Launches the grouped NVFP4 GEMM kernel over a 1-D grid of total_tiles
// blocks; each block resolves its (expert, m_tile, n_tile) from blockIdx.x
// inside the kernel. total_tiles is expected to cover
// cumul_m_tiles[num_experts] * ceil(N / BLOCK_N_DIM) tiles — TODO confirm
// against callers. No dynamic shared memory; no allocations (CUDA-graph-safe).
template <typename OutT>
static void launch_grouped_gemm_nvfp4(
    const unsigned char* A_concat, const unsigned char* B_all,
    const unsigned char* SFA_concat, const unsigned char* SFB_all,
    OutT* D_concat, const int* expert_offsets, const int* cumul_m_tiles,
    int N, int K, int num_experts, int total_tiles,
    cudaStream_t stream
) {
    // A zero-sized grid dimension is an invalid launch configuration
    // (the launch fails with cudaErrorInvalidConfiguration). Skip the
    // launch entirely when there is no work, e.g. every expert received
    // zero tokens in this batch.
    if (total_tiles <= 0) return;

    const int threads_per_block = WARPS_PER_BLOCK * 32;
    dim3 grid(total_tiles, 1, 1);
    kGroupedGemmNVFP4_smem<OutT><<<grid, threads_per_block, 0, stream>>>(
        A_concat, B_all, SFA_concat, SFB_all, D_concat,
        expert_offsets, cumul_m_tiles, N, K, num_experts
    );
}
737-
// C-ABI entry point: grouped NVFP4 GEMM producing bf16 output, for MoE
// inference. Thin shim that forwards every argument unchanged to the
// launch_grouped_gemm_nvfp4 template instantiated for __nv_bfloat16.
extern "C" void cgemm_nvfp4_grouped_bf16(
    const unsigned char* A_concat, const unsigned char* B_all,
    const unsigned char* SFA_concat, const unsigned char* SFB_all,
    __nv_bfloat16* D_concat, const int* expert_offsets, const int* cumul_m_tiles,
    int N, int K, int num_experts, int total_tiles, cudaStream_t stream
) {
    launch_grouped_gemm_nvfp4<__nv_bfloat16>(A_concat, B_all, SFA_concat,
                                             SFB_all, D_concat, expert_offsets,
                                             cumul_m_tiles, N, K, num_experts,
                                             total_tiles, stream);
}
0 commit comments