Skip to content

Commit 1732592

Browse files
TimDettmers and claude
committed
Remove dead hand-written grouped GEMM from kernels_nvfp4_sm120.cu
The grouped MoE kernel (kGroupedGemmNVFP4_smem, cgemm_nvfp4_grouped_bf16) was superseded by the batched CUTLASS kernel in gemm_nvfp4_sm120.cu. No Python code calls this function anymore — the SM_120 grouped path raises NotImplementedError directing callers to the batched API. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0c0f64b commit 1732592

File tree

1 file changed

+0
-264
lines changed

1 file changed

+0
-264
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 0 additions & 264 deletions
Original file line numberDiff line numberDiff line change
@@ -482,267 +482,3 @@ extern "C" void cgemm_nvfp4_bf16(
482482

483483
launch_gemm_nvfp4<__nv_bfloat16>(A, B, SFA, SFB, D, workspace, M, N, K, split_k, stream);
484484
}
485-
486-
// ============================================================================
// Grouped NVFP4 GEMM for MoE inference
//
// Fuses all expert GEMMs into a single kernel launch. Each threadblock
// handles one (m_tile, n_tile) for one expert, determined by linear
// blockIdx.x decomposition over a precomputed cumulative m-tile table.
//
// A_concat: [total_tokens, K/2] -- all expert activations concatenated
// B_all: [num_experts * N, K/2] -- per-expert weights stacked
// SFA_concat: [total_tokens, K/16] -- per-token activation scales
// SFB_all: [num_experts * N, K/16]-- per-expert weight scales stacked
// D_concat: [total_tokens, N] -- output (pre-allocated)
// expert_offsets:[num_experts+1] -- cumulative token offsets (int32)
// cumul_m_tiles:[num_experts+1] -- cumulative m-tile counts (int32)
//
// No split-K: expert parallelism provides sufficient tile count.
// CUDA-graph-safe: no dynamic allocations or cudaMemset.
//
// Launch shape (see launch_grouped_gemm_nvfp4): grid.x = total number of
// (m_tile, n_tile) pairs summed over experts; block = WARPS_PER_BLOCK * 32
// threads. Packed-FP4 operands store two 4-bit values per byte, hence the
// K/2 byte widths above.
// ============================================================================

template <typename OutT>
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGroupedGemmNVFP4_smem(
    const unsigned char* __restrict__ A_concat,
    const unsigned char* __restrict__ B_all,
    const unsigned char* __restrict__ SFA_concat,
    const unsigned char* __restrict__ SFB_all,
    OutT* __restrict__ D_concat,
    const int* __restrict__ expert_offsets,
    const int* __restrict__ cumul_m_tiles,
    int N, int K, int num_experts
) {
    // Decompose the linear block index into (global m-tile, n-tile).
    int tile_idx = blockIdx.x;
    int num_n_tiles = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
    int m_tile_global = tile_idx / num_n_tiles;
    int n_tile = tile_idx % num_n_tiles;

    // Binary search for expert owning this m_tile_global
    // (cumul_m_tiles is nondecreasing; finds smallest `lo` with
    // cumul_m_tiles[lo + 1] > m_tile_global).
    int lo = 0, hi = num_experts;
    while (lo < hi) {
        int mid = (lo + hi) / 2;
        if (cumul_m_tiles[mid + 1] <= m_tile_global)
            lo = mid + 1;
        else
            hi = mid;
    }
    int expert = lo;
    if (expert >= num_experts) return;  // grid padding past the last expert

    int local_m_tile = m_tile_global - cumul_m_tiles[expert];
    int expert_M = expert_offsets[expert + 1] - expert_offsets[expert];
    if (expert_M <= 0) return;  // expert received no tokens this batch

    int row_offset = expert_offsets[expert];
    int half_K = K / 2;       // bytes per row of packed FP4 (2 values/byte)
    int scale_K = K / 16;     // one scale byte per 16 K-elements
    int scale_n_col_blocks = (scale_K + 3) / 4;  // 4-byte column blocks in swizzled scale layout

    // Point to this expert's data (packed FP4 uses flat per-expert offsets,
    // scales use swizzled layout with absolute row indices)
    const unsigned char* A = A_concat + (size_t)row_offset * half_K;
    const unsigned char* B = B_all + (size_t)expert * N * half_K;
    OutT* D = D_concat + (size_t)row_offset * N;
    int M = expert_M;  // rows belonging to this expert only

    // --- Standard tile GEMM (same logic as kGemmNVFP4_smem, no split-K) ---
    // Single statically-sized shared buffer partitioned into A / B / scale
    // regions; sizes come from the SMEM_* file constants.
    __shared__ __align__(16) unsigned char smem[SMEM_TOTAL];
    unsigned char* smem_A = smem;
    unsigned char* smem_B = smem + SMEM_A_BYTES;
    unsigned char* smem_SFA = smem + SMEM_A_BYTES + SMEM_B_BYTES;
    unsigned char* smem_SFB = smem + SMEM_A_BYTES + SMEM_B_BYTES + SMEM_SFA_BYTES;

    const int tid = threadIdx.x;
    const int warp_in_block = tid / 32;
    const int lane_id = tid % 32;
    // Warps tile the block's output as an (M-warps x N_WARPS) grid.
    const int m_warp = warp_in_block / N_WARPS;
    const int n_warp = warp_in_block % N_WARPS;

    const int block_m = local_m_tile * BLOCK_M_DIM;
    const int block_n = n_tile * BLOCK_N_DIM;
    const int tile_m = block_m + m_warp * 16;   // each warp owns a 16-row MMA slice
    const int warp_n_base = block_n + n_warp * N_TILES_PER_WARP * 8;  // 8 cols per MMA tile

    // Lane decomposition used by the m16n8k64 MMA fragment layout.
    const int t0 = lane_id % 4;
    const int t1 = lane_id / 4;

    // Per-thread accumulator: 4 floats per n-tile (the MMA's c fragments).
    float acc[N_TILES_PER_WARP][4];
#pragma unroll
    for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
        acc[nt][0] = acc[nt][1] = acc[nt][2] = acc[nt][3] = 0.0f;
    }

    // Rows of smem_A this lane reads for its A fragments.
    const int a_local_row0 = m_warp * 16 + 2 * t1;
    const int a_local_row1 = a_local_row0 + 1;
    // Scale-fragment row mapping (mirrors the CUTE scale-factor layout used
    // when the scales were swizzled -- see swizzled_scale_offset).
    const int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
    const int cute_sf_m0 = sf_tidx % 16;
    const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8) * 2 + cute_sf_m0 / 8;

    // Global-load assignment: each thread copies 4 bytes of A ...
    const int a_off = tid * 4;
    const int a_load_row = a_off >> 5;  // 32 bytes (64 FP4 values) per K-step row
    const int a_load_col = a_off & 31;
    const int a_gm = block_m + a_load_row;

    // ... and 16 bytes of B per K-step.
    const int b_off = tid * 16;
    const int b_load_row = b_off >> 5;
    const int b_load_col = b_off & 31;
    const int b_gn = block_n + b_load_row;

    const bool a_gm_ok = (a_gm < M);
    const bool b_gn_ok = (b_gn < N);
    const int a_row_base = a_gm * half_K;
    const int b_row_base = b_gn * half_K;

    // Pipeline registers: next K-step's data is staged here while the
    // current step is consumed from shared memory.
    uint32_t pipe_a = 0;
    uint4 pipe_b = make_uint4(0, 0, 0, 0);
    uint32_t pipe_sfa = 0, pipe_sfb = 0;

    // --- Pipelined K-loop with inlined load/store/compute ---

    // Load helper: global -> registers for K-byte offset k_byte and scale
    // offset k_scale. Out-of-range lanes/bytes yield zeros, which contribute
    // nothing to the accumulators (tail handling).
    auto do_load = [&](int k_byte, int k_scale) {
        pipe_a = 0;
        if (a_gm_ok) {
            int ga = a_row_base + k_byte + a_load_col;
            if (k_byte + a_load_col + 3 < half_K)
                pipe_a = *(const uint32_t*)(A + ga);          // fast path: full 4-byte load
            else
                for (int i = 0; i < 4; i++)                   // K tail: byte-wise guarded load
                    if (k_byte + a_load_col + i < half_K)
                        pipe_a |= ((uint32_t)A[ga + i]) << (i * 8);
        }
        if (b_gn_ok) {
            int gb = b_row_base + k_byte + b_load_col;
            if (k_byte + b_load_col + 15 < half_K) {
                uint4 bv = *(const uint4*)(B + gb);           // fast path: full 16-byte load
                pipe_b.x = bv.x; pipe_b.y = bv.y; pipe_b.z = bv.z; pipe_b.w = bv.w;
            } else {
                unsigned char buf[16] = {};                   // K tail: zero-padded staging buffer
                for (int i = 0; i < 16; i++)
                    if (k_byte + b_load_col + i < half_K) buf[i] = B[gb + i];
                pipe_b = *(uint4*)buf;
            }
        } else { pipe_b = make_uint4(0, 0, 0, 0); }

        // Scales: first BLOCK_M_DIM threads fetch one 4-byte A-scale group
        // per row; swizzled layout is addressed with ABSOLUTE row indices
        // (hence row_offset + gm / expert * N + gn below).
        pipe_sfa = 0;
        if (tid < BLOCK_M_DIM) {
            int gm = block_m + tid;
            if (gm < M) {
                int bs = swizzled_scale_offset(row_offset + gm, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfa = *(const uint32_t*)(SFA_concat + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfa |= ((uint32_t)SFA_concat[bs + i]) << (i * 8);
            }
        }
        pipe_sfb = 0;
        if (tid < BLOCK_N_DIM) {
            int gn = block_n + tid;
            if (gn < N) {
                int bs = swizzled_scale_offset(expert * N + gn, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfb = *(const uint32_t*)(SFB_all + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfb |= ((uint32_t)SFB_all[bs + i]) << (i * 8);
            }
        }
    };

    // Commit the staged registers to shared memory (must be followed by a
    // __syncthreads() before any thread reads them back).
    auto do_store = [&]() {
        *(uint32_t*)(smem_A + a_off) = pipe_a;
        *(uint4*)(smem_B + b_off) = pipe_b;
        if (tid < BLOCK_M_DIM) *(uint32_t*)(smem_SFA + tid * 4) = pipe_sfa;
        if (tid < BLOCK_N_DIM) *(uint32_t*)(smem_SFB + tid * 4) = pipe_sfb;
    };

    // One 64-element K-step of MMAs over this warp's N_TILES_PER_WARP tiles,
    // reading fragments from shared memory.
    auto do_compute = [&]() {
        uint32_t ar[4];
        ar[0] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4);
        ar[1] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4);
        ar[2] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4 + 16);
        ar[3] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4 + 16);
        uint32_t sf = *(const uint32_t*)(smem_SFA + sfa_local_row * 4);
#pragma unroll
        for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
            int ln = n_warp * N_TILES_PER_WARP * 8 + nt * 8;
            int br = ln + t1;
            uint32_t b0 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4);
            uint32_t b1 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4 + 16);
            uint32_t sb = *(const uint32_t*)(smem_SFB + (ln + t1) * 4);
            // Scaled FP4 tensor-core MMA; accumulates in-place into acc[nt].
            mma_nvfp4_m16n8k64(
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3],
                ar[0], ar[1], ar[2], ar[3], b0, b1,
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], sf, sb
            );
        }
    };

    // Load first K-step
    do_load(0, 0);
    do_store();
    __syncthreads();

    // Software pipeline: prefetch K-step k+64 into registers while computing
    // step k from shared memory; the two barriers separate compute-reads
    // from the next store and the store from the next compute.
    for (int k_start = 0; k_start < K; k_start += 64) {
        bool has_next = (k_start + 64 < K);
        if (has_next) do_load((k_start + 64) / 2, (k_start + 64) / 16);
        do_compute();
        __syncthreads();
        if (has_next) { do_store(); __syncthreads(); }
    }

    // Write output (no split-K, direct store)
    // Accumulator fragment -> (row, col) mapping of the m16n8 MMA C layout.
    int octet = lane_id / 4;
    int quad = lane_id % 4;
    int out_row0 = tile_m + octet * 2;
    int out_row1 = out_row0 + 1;
    int out_col_base = quad * 2;

#pragma unroll
    for (int nt = 0; nt < N_TILES_PER_WARP; nt++) {
        int this_tile_n = warp_n_base + nt * 8;
        int c0 = this_tile_n + out_col_base;
        int c1 = c0 + 1;
        // Bounds-guarded stores handle the ragged M/N tail of each expert.
        if (out_row0 < M && c0 < N) D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]);
        if (out_row0 < M && c1 < N) D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]);
        if (out_row1 < M && c0 < N) D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]);
        if (out_row1 < M && c1 < N) D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]);
    }
}
717-
718-
// ============================================================================
// Grouped GEMM launchers
// ============================================================================

// Host-side launch of the fused grouped NVFP4 GEMM. One threadblock per
// (m_tile, n_tile) pair summed over all experts, so `total_tiles` must be
// the precomputed grid size matching the kernel's blockIdx.x decomposition.
// NOTE(review): expected to be cumul_m_tiles[num_experts] * ceil(N /
// BLOCK_N_DIM) — confirm against the caller that builds cumul_m_tiles.
// Asynchronous with respect to the host; work is enqueued on `stream`.
template <typename OutT>
static void launch_grouped_gemm_nvfp4(
    const unsigned char* A_concat, const unsigned char* B_all,
    const unsigned char* SFA_concat, const unsigned char* SFB_all,
    OutT* D_concat, const int* expert_offsets, const int* cumul_m_tiles,
    int N, int K, int num_experts, int total_tiles,
    cudaStream_t stream
) {
    const dim3 grid_dim(total_tiles);
    const dim3 block_dim(WARPS_PER_BLOCK * 32);  // matches the kernel's __launch_bounds__
    kGroupedGemmNVFP4_smem<OutT><<<grid_dim, block_dim, 0, stream>>>(
        A_concat, B_all, SFA_concat, SFB_all,
        D_concat, expert_offsets, cumul_m_tiles,
        N, K, num_experts);
}
737-
738-
// C-linkage entry point for the grouped NVFP4 GEMM with bf16 outputs.
// Thin forwarding wrapper: instantiates the templated launcher for
// __nv_bfloat16 and passes every argument through unchanged. Exposed
// extern "C" so it can be bound from FFI callers without C++ mangling.
extern "C" void cgemm_nvfp4_grouped_bf16(
    const unsigned char* A_concat, const unsigned char* B_all,
    const unsigned char* SFA_concat, const unsigned char* SFB_all,
    __nv_bfloat16* D_concat, const int* expert_offsets, const int* cumul_m_tiles,
    int N, int K, int num_experts, int total_tiles, cudaStream_t stream
) {
    launch_grouped_gemm_nvfp4<__nv_bfloat16>(A_concat, B_all, SFA_concat, SFB_all,
                                             D_concat, expert_offsets, cumul_m_tiles,
                                             N, K, num_experts, total_tiles, stream);
}

0 commit comments

Comments
 (0)