// Batched NVFP4 MoE GEMM for SM_120a (consumer Blackwell: RTX 5090, RTX 6000)
//
// CUDA-graph friendly: fixed grid (n_tiles, m_tiles, num_experts), no dynamic
// routing, no host-device sync. All experts compute max_M rows; padded rows
// produce ignored output that the caller discards.
//
// Uses the same hand-written PTX mma.sync as the existing grouped kernel,
// but with batched layout instead of concatenated+binary-search.
//
// Data layout:
//   A_batched:   (num_experts, max_M, K/2) packed FP4, row-major per expert
//   B_all:       (num_experts, N, K/2) packed FP4, row-major per expert
//   SFA_batched: (num_experts, sfa_per_expert_bytes) per-expert swizzled scales
//   SFB_all:     (num_experts, sfb_per_expert_bytes) per-expert swizzled scales
//   D_batched:   (num_experts, max_M, N) BF16 output, row-major per expert
//
// No alpha epilogue — tensor scales applied post-hoc in Python (same as
// existing SM_120 grouped kernel).
//
// Must be compiled with: -gencode=arch=compute_120a,code=sm_120a
| 21 | + |
| 22 | +#include <cstdint> |
| 23 | +#include <cuda_bf16.h> |
| 24 | +#include <cuda_fp16.h> |
| 25 | +#include <cuda_runtime.h> |
| 26 | +#include <type_traits> |
| 27 | + |
// ============================================================================
// MMA wrapper: m16n8k64 E2M1 x E2M1 -> F32 with UE4M3 block scales
//
// Wraps the SM_120a block-scaled tensor-core instruction
//   mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X
//       .m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3
//
// Per-lane operands (fragment layout fixed by the instruction):
//   a0..a3  - four 32-bit regs of packed E2M1 (FP4) A-fragment data
//   b0,b1   - two 32-bit regs of packed E2M1 (FP4) B-fragment data
//   c0..c3  - F32 accumulator input fragment
//   sfa,sfb - packed UE4M3 block-scale bytes for A and B (scale_vec::4X)
//   d0..d3  - F32 accumulator output fragment (call sites pass the same
//             variables for c and d to accumulate in place)
//
// bidA/tidA/bidB/tidB are the scale byte-id / thread-id selector operands;
// they are hard-wired to 0 here (selector semantics per the PTX ISA spec —
// NOTE(review): assumed to match the scale packing produced upstream).
// ============================================================================
__device__ __forceinline__ void mma_nvfp4_m16n8k64(
    float& d0, float& d1, float& d2, float& d3, uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3, uint32_t b0,
    uint32_t b1, float c0, float c1, float c2, float c3, uint32_t sfa, uint32_t sfb
) {
    uint16_t bidA = 0, tidA = 0, bidB = 0, tidB = 0;
    asm volatile("mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X"
                 ".m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13},"
                 "{%14},"
                 "{%15, %16},"
                 "{%17},"
                 "{%18, %19};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
                 : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(sfa),
                   "h"(bidA), "h"(tidA), "r"(sfb), "h"(bidB), "h"(tidB));
}
| 50 | + |
// ============================================================================
// Swizzled scale index computation (same as existing kernel)
//
// Maps a (row, col) coordinate of the logical scale matrix to a byte offset
// in the swizzled layout: the matrix is tiled into 128-row x 4-col blocks of
// 512 bytes each (row-major over blocks), and within a block the bytes are
// interleaved as (row % 32) * 16 + (row / 32) * 4 + col.
// Assumes non-negative row/col (shifts and masks implement /128, /4, %128, %4).
// ============================================================================
__device__ __forceinline__ int swizzled_scale_offset(int row, int col, int n_col_blocks) {
    const int blk_r = row >> 7; // row / 128
    const int blk_c = col >> 2; // col / 4
    const int rr = row & 127;   // row % 128
    const int cc = col & 3;     // col % 4
    return (blk_r * n_col_blocks + blk_c) * 512 + (rr & 31) * 16 + (rr >> 5) * 4 + cc;
}
| 62 | + |
// ============================================================================
// Output conversion helpers
//
// float_to_out<T>(v): narrow an F32 accumulator value to output type T.
// Specialized for float (identity), __nv_bfloat16, and half via the CUDA
// conversion intrinsics.
// ============================================================================
template <typename T> __device__ __forceinline__ T float_to_out(float v);
// float output: pass-through.
template <> __device__ __forceinline__ float float_to_out<float>(float v) { return v; }
// bf16 output.
template <> __device__ __forceinline__ __nv_bfloat16 float_to_out<__nv_bfloat16>(float v) {
    return __float2bfloat16(v);
}
// fp16 output.
template <> __device__ __forceinline__ half float_to_out<half>(float v) { return __float2half(v); }
| 72 | + |
// ============================================================================
// Block tile dimensions (same as existing kernel)
//
// 8 warps per block arranged as a 2 (M) x 4 (N) warp grid; each warp owns
// 4 n-tiles of 8 columns, so a block produces a 32 x 128 output tile.
// Shared-memory staging sizes are for one 64-element K-step: 64 FP4 values
// pack into 32 bytes per row, and each row/column gets one uint32 (4 bytes)
// of scale data.
// ============================================================================
#define MOE_N_TILES_PER_WARP 4
#define MOE_M_WARPS 2
#define MOE_N_WARPS 4
#define MOE_WARPS_PER_BLOCK (MOE_M_WARPS * MOE_N_WARPS) // 8
#define MOE_BLOCK_M_DIM (MOE_M_WARPS * 16) // 32
#define MOE_BLOCK_N_DIM (MOE_N_WARPS * MOE_N_TILES_PER_WARP * 8) // 128
#define MOE_SMEM_A_BYTES (MOE_BLOCK_M_DIM * 32) // 1024
#define MOE_SMEM_B_BYTES (MOE_BLOCK_N_DIM * 32) // 4096
#define MOE_SMEM_SFA_BYTES (MOE_BLOCK_M_DIM * 4) // 128
#define MOE_SMEM_SFB_BYTES (MOE_BLOCK_N_DIM * 4) // 512
#define MOE_SMEM_TOTAL (MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES + MOE_SMEM_SFA_BYTES + MOE_SMEM_SFB_BYTES)
| 87 | + |
// ============================================================================
// Batched MoE GEMM kernel
//
// Grid: (num_n_tiles, num_m_tiles, num_experts)
//   blockIdx.x = n_tile index
//   blockIdx.y = m_tile index
//   blockIdx.z = expert index
// Block: MOE_WARPS_PER_BLOCK * 32 = 256 threads; each block computes one
// 32x128 output tile for one expert.
//
// Each expert computes max_M × N output with its own activation/weight data.
// Padded rows (beyond actual token count) produce garbage that the caller
// discards during the gather step.
//
// The K loop is a single-stage software pipeline: the next K-step's global
// loads are staged into registers (pipe_*) while the current K-step is
// computed out of shared memory.
// ============================================================================
template <typename OutT>
__global__ __launch_bounds__(MOE_WARPS_PER_BLOCK * 32, 4) void kBatchedMoeGemmNVFP4(
    const unsigned char* __restrict__ A_batched, // (num_experts, max_M, K/2)
    const unsigned char* __restrict__ B_all, // (num_experts, N, K/2)
    const unsigned char* __restrict__ SFA_batched, // per-expert swizzled act scales
    const unsigned char* __restrict__ SFB_all, // per-expert swizzled wt scales
    OutT* __restrict__ D_batched, // (num_experts, max_M, N)
    int max_M, int N, int K,
    int sfa_per_expert_bytes, // size of one expert's SFA block
    int sfb_per_expert_bytes // size of one expert's SFB block
) {
    const int expert = blockIdx.z;
    const int n_tile = blockIdx.x;
    const int m_tile = blockIdx.y;

    const int half_K = K / 2;   // packed bytes per row (2 FP4 values per byte)
    const int scale_K = K / 16; // one UE4M3 scale per 16 FP4 values
    const int scale_n_col_blocks = (scale_K + 3) / 4; // swizzle blocks span 4 scale columns

    // Point to this expert's data in the batched layout
    const unsigned char* A = A_batched + (size_t)expert * max_M * half_K;
    const unsigned char* B = B_all + (size_t)expert * N * half_K;
    const unsigned char* SFA = SFA_batched + (size_t)expert * sfa_per_expert_bytes;
    const unsigned char* SFB = SFB_all + (size_t)expert * sfb_per_expert_bytes;
    OutT* D = D_batched + (size_t)expert * max_M * N;
    const int M = max_M; // compute all rows including padding

    // --- Standard tile GEMM (same core logic as kGemmNVFP4_smem) ---
    // Staging buffers for one K-step: A tile (1024 B), B tile (4096 B), and
    // one uint32 of scale bytes per A row (128 B) / B column (512 B).
    __shared__ __align__(16) unsigned char smem[MOE_SMEM_TOTAL];
    unsigned char* smem_A = smem;
    unsigned char* smem_B = smem + MOE_SMEM_A_BYTES;
    unsigned char* smem_SFA = smem + MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES;
    unsigned char* smem_SFB = smem + MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES + MOE_SMEM_SFA_BYTES;

    const int tid = threadIdx.x;
    const int warp_in_block = tid / 32;
    const int lane_id = tid % 32;
    const int m_warp = warp_in_block / MOE_N_WARPS; // 0..1: warp row in the 2x4 warp grid
    const int n_warp = warp_in_block % MOE_N_WARPS; // 0..3: warp column

    const int block_m = m_tile * MOE_BLOCK_M_DIM; // first output row of this block (within expert)
    const int block_n = n_tile * MOE_BLOCK_N_DIM; // first output column of this block
    const int tile_m = block_m + m_warp * 16;     // first output row owned by this warp
    const int warp_n_base = block_n + n_warp * MOE_N_TILES_PER_WARP * 8; // first column owned by this warp

    // Lane decomposition used for fragment addressing below.
    const int t0 = lane_id % 4; // 4-byte group along K within a 16-byte half
    const int t1 = lane_id / 4; // row pair (A) / column (B) selector

    // Per-lane accumulators: 4 n-tiles x 4 F32 values.
    float acc[MOE_N_TILES_PER_WARP][4];
    #pragma unroll
    for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
        acc[nt][0] = acc[nt][1] = acc[nt][2] = acc[nt][3] = 0.0f;
    }

    // A-fragment rows use an interleaved (2*t1, 2*t1+1) pairing; the epilogue
    // stores rows with the matching (octet*2, octet*2+1) pairing, so the row
    // permutation applied here cancels out in the final output.
    const int a_local_row0 = m_warp * 16 + 2 * t1;
    const int a_local_row1 = a_local_row0 + 1;
    // Scale-factor row for the permuted A rows (mirrors the cute-style
    // scale-fragment mapping of the existing kernel — inherited unchanged,
    // NOTE(review): assumed consistent with the A-row permutation above).
    const int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
    const int cute_sf_m0 = sf_tidx % 16;
    const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8) * 2 + cute_sf_m0 / 8;

    // Global-load assignment for A: 4 bytes per thread; 256 threads cover the
    // full 32-row x 32-byte tile (32 bytes = one K-step per row).
    const int a_off = tid * 4;
    const int a_load_row = a_off >> 5; // smem row
    const int a_load_col = a_off & 31; // byte offset within the row
    const int a_gm = block_m + a_load_row;

    // Global-load assignment for B: 16 bytes per thread; 256 threads cover
    // the full 128-row x 32-byte tile.
    const int b_off = tid * 16;
    const int b_load_row = b_off >> 5;
    const int b_load_col = b_off & 31;
    const int b_gn = block_n + b_load_row;

    const bool a_gm_ok = (a_gm < M);
    const bool b_gn_ok = (b_gn < N);
    const int a_row_base = a_gm * half_K;
    const int b_row_base = b_gn * half_K;

    // Pipeline registers (hold the NEXT K-step between do_load and do_store)
    uint32_t pipe_a = 0;
    uint4 pipe_b = make_uint4(0, 0, 0, 0);
    uint32_t pipe_sfa = 0, pipe_sfb = 0;

    // Load helper — uses per-expert local offsets for scales.
    // k_byte: byte offset into a packed FP4 row; k_scale: scale-column index.
    // Out-of-range rows/columns and K-tail bytes are zero-filled so partial
    // tiles contribute nothing to the accumulators.
    auto do_load = [&](int k_byte, int k_scale) {
        pipe_a = 0;
        if (a_gm_ok) {
            int ga = a_row_base + k_byte + a_load_col;
            if (k_byte + a_load_col + 3 < half_K)
                pipe_a = *(const uint32_t*)(A + ga);
            else
                // K tail: assemble the word byte-by-byte, zero-padding past K.
                for (int i = 0; i < 4; i++)
                    if (k_byte + a_load_col + i < half_K)
                        pipe_a |= ((uint32_t)A[ga + i]) << (i * 8);
        }
        if (b_gn_ok) {
            int gb = b_row_base + k_byte + b_load_col;
            if (k_byte + b_load_col + 15 < half_K) {
                // Fast path: full 16-byte vectorized load.
                uint4 bv = *(const uint4*)(B + gb);
                pipe_b.x = bv.x; pipe_b.y = bv.y; pipe_b.z = bv.z; pipe_b.w = bv.w;
            } else {
                // K tail: gather valid bytes into a zero-initialized buffer.
                unsigned char buf[16] = {};
                for (int i = 0; i < 16; i++)
                    if (k_byte + b_load_col + i < half_K) buf[i] = B[gb + i];
                pipe_b = *(uint4*)buf;
            }
        } else { pipe_b = make_uint4(0, 0, 0, 0); }

        // SFA: per-expert swizzled layout (local row indices within this expert).
        // Only the first MOE_BLOCK_M_DIM threads fetch: one uint32 (4 scale
        // columns) per A row of the tile.
        pipe_sfa = 0;
        if (tid < MOE_BLOCK_M_DIM) {
            int gm = block_m + tid;
            if (gm < M) {
                int bs = swizzled_scale_offset(gm, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfa = *(const uint32_t*)(SFA + bs);
                else
                    // Scale tail: byte-wise, zero-padded past scale_K.
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfa |= ((uint32_t)SFA[bs + i]) << (i * 8);
            }
        }
        // SFB: per-expert swizzled layout (local row indices within this expert).
        // Same scheme as SFA, one uint32 per B row (= output column).
        pipe_sfb = 0;
        if (tid < MOE_BLOCK_N_DIM) {
            int gn = block_n + tid;
            if (gn < N) {
                int bs = swizzled_scale_offset(gn, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfb = *(const uint32_t*)(SFB + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfb |= ((uint32_t)SFB[bs + i]) << (i * 8);
            }
        }
    };

    // Commit the staged registers to shared memory. The surrounding code is
    // responsible for the __syncthreads() that separates these stores from
    // the reads in do_compute.
    auto do_store = [&]() {
        *(uint32_t*)(smem_A + a_off) = pipe_a;
        *(uint4*)(smem_B + b_off) = pipe_b;
        if (tid < MOE_BLOCK_M_DIM) *(uint32_t*)(smem_SFA + tid * 4) = pipe_sfa;
        if (tid < MOE_BLOCK_N_DIM) *(uint32_t*)(smem_SFB + tid * 4) = pipe_sfb;
    };

    // One K-step of tensor-core work: each warp issues MOE_N_TILES_PER_WARP
    // block-scaled m16n8k64 MMAs against its 16-row A fragment.
    auto do_compute = [&]() {
        uint32_t ar[4];
        // A fragment: interleaved row pair (row0, row1) x two 16-byte K halves.
        ar[0] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4);
        ar[1] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4);
        ar[2] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4 + 16);
        ar[3] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4 + 16);
        uint32_t sf = *(const uint32_t*)(smem_SFA + sfa_local_row * 4);
        #pragma unroll
        for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
            int ln = n_warp * MOE_N_TILES_PER_WARP * 8 + nt * 8; // first local column of this n-tile
            int br = ln + t1; // B row (= output column) read by this lane
            uint32_t b0 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4);
            uint32_t b1 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4 + 16);
            uint32_t sb = *(const uint32_t*)(smem_SFB + (ln + t1) * 4);
            mma_nvfp4_m16n8k64(
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3],
                ar[0], ar[1], ar[2], ar[3], b0, b1,
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], sf, sb
            );
        }
    };

    // Load first K-step
    do_load(0, 0);
    do_store();
    __syncthreads();

    // Main K loop: 64 FP4 elements (32 packed bytes, 4 scale columns) per
    // step. While the current step computes from smem, the next step is
    // loaded into the pipe_* registers. The barrier after do_compute guards
    // smem reuse before the next do_store overwrites it.
    for (int k_start = 0; k_start < K; k_start += 64) {
        bool has_next = (k_start + 64 < K);
        if (has_next) do_load((k_start + 64) / 2, (k_start + 64) / 16);
        do_compute();
        __syncthreads();
        if (has_next) { do_store(); __syncthreads(); }
    }

    // Write output (no split-K, direct store). Each lane owns a 2x2 patch per
    // n-tile: rows (tile_m + octet*2, +1), columns (quad*2, +1) — the same
    // interleaved row pairing used when loading the A fragment above.
    int octet = lane_id / 4;
    int quad = lane_id % 4;
    int out_row0 = tile_m + octet * 2;
    int out_row1 = out_row0 + 1;
    int out_col_base = quad * 2;

    #pragma unroll
    for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
        int this_tile_n = warp_n_base + nt * 8;
        int c0 = this_tile_n + out_col_base;
        int c1 = c0 + 1;
        // Bounds checks handle ragged tails in both M and N.
        if (out_row0 < M && c0 < N) D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]);
        if (out_row0 < M && c1 < N) D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]);
        if (out_row1 < M && c0 < N) D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]);
        if (out_row1 < M && c1 < N) D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]);
    }
}
| 295 | + |
// ============================================================================
// Launcher and C interface
// ============================================================================

// Launch kBatchedMoeGemmNVFP4 over a fixed (n_tiles, m_tiles, num_experts)
// grid on the given stream (CUDA-graph capturable: no sync, no allocation).
//
// Parameters mirror the kernel; sizes are in elements except the two
// *_per_expert_bytes scale strides.
//
// Fix vs. original: degenerate sizes (max_M, N, or num_experts <= 0) would
// have produced a zero grid dimension, which is an invalid launch
// configuration (cudaErrorInvalidConfiguration). Guard and no-op instead —
// zero work means nothing to compute, and the sticky error would otherwise
// poison later CUDA calls.
template <typename OutT>
static void launch_batched_moe_gemm_nvfp4(
    const unsigned char* A_batched, const unsigned char* B_all,
    const unsigned char* SFA_batched, const unsigned char* SFB_all,
    OutT* D_batched,
    int max_M, int N, int K, int num_experts,
    int sfa_per_expert_bytes, int sfb_per_expert_bytes,
    cudaStream_t stream
) {
    // No work → no launch (a zero grid dim is a launch error, not a no-op).
    if (max_M <= 0 || N <= 0 || num_experts <= 0) return;

    int num_m_tiles = (max_M + MOE_BLOCK_M_DIM - 1) / MOE_BLOCK_M_DIM; // ceil-div
    int num_n_tiles = (N + MOE_BLOCK_N_DIM - 1) / MOE_BLOCK_N_DIM;     // ceil-div
    int threads_per_block = MOE_WARPS_PER_BLOCK * 32; // 256

    // grid.z carries the expert index (hardware limit 65535, far above any
    // practical expert count).
    dim3 grid(num_n_tiles, num_m_tiles, num_experts);
    kBatchedMoeGemmNVFP4<OutT><<<grid, threads_per_block, 0, stream>>>(
        A_batched, B_all, SFA_batched, SFB_all, D_batched,
        max_M, N, K, sfa_per_expert_bytes, sfb_per_expert_bytes
    );
    // NOTE(review): launch errors are not checked here; callers are expected
    // to surface them via cudaGetLastError()/stream error checking upstream.
}
| 319 | + |
// C-linkage entry point for the BF16-output batched MoE GEMM: forwards all
// arguments to the templated launcher (OutT deduced as __nv_bfloat16 from
// the D_batched pointer type).
extern "C" void cgemm_nvfp4_moe_bf16(
    const unsigned char* A_batched,
    const unsigned char* B_all,
    const unsigned char* SFA_batched,
    const unsigned char* SFB_all,
    __nv_bfloat16* D_batched,
    int max_M, int N, int K, int num_experts,
    int sfa_per_expert_bytes, int sfb_per_expert_bytes,
    cudaStream_t stream
) {
    launch_batched_moe_gemm_nvfp4(
        A_batched, B_all, SFA_batched, SFB_all, D_batched, max_M, N, K,
        num_experts, sfa_per_expert_bytes, sfb_per_expert_bytes, stream
    );
}