Skip to content

Commit 5312321

Browse files
TimDettmers and claude committed
Add CUDA graph support and consolidate dense+MoE GEMM into single TU
Consolidate dense and MoE NVFP4 GEMM into gemm_nvfp4_sm120.cu with a shared types header. Add an init/run split for CUDA graph capture: initGemmAdapter handles can_implement + initialize (non-capturable cudaFuncSetAttribute), while launchGemm handles run-only work (the graph-capturable kernel launch).

Key design decisions:
- void* type erasure in initGemmAdapter/launchGemm to work around an nvcc bug where GemmUniversalAdapter objects cannot bind to template reference params
- runGemm (non-graph path) creates local Gemm objects, avoiding the bug
- Persistent GemmState structs hold initialized Gemm objects between init/run
- Exactly 2 device_kernel instantiations (small 128x128x128 + large 256x128x128)
- Requires sm_120a (not sm_120) for block-scaled MMA instructions

Tested on RTX 5090: NVFP4 is 1.7-10x faster than BF16 torch.mm at MoE sizes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f8aa403 commit 5312321

File tree

4 files changed

+1158
-112
lines changed

4 files changed

+1158
-112
lines changed

csrc/kernels_nvfp4_moe_sm120.cu

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
// Batched NVFP4 MoE GEMM for SM_120a (consumer Blackwell: RTX 5090, RTX 6000)
2+
//
3+
// CUDA-graph friendly: fixed grid (n_tiles, m_tiles, num_experts), no dynamic
4+
// routing, no host-device sync. All experts compute max_M rows; padded rows
5+
// produce ignored output that the caller discards.
6+
//
7+
// Uses the same hand-written PTX mma.sync as the existing grouped kernel,
8+
// but with batched layout instead of concatenated+binary-search.
9+
//
10+
// Data layout:
11+
// A_batched: (num_experts, max_M, K/2) packed FP4, row-major per expert
12+
// B_all: (num_experts, N, K/2) packed FP4, row-major per expert
13+
// SFA_batched: (num_experts, sfa_per_expert_bytes) per-expert swizzled scales
14+
// SFB_all: (num_experts, sfb_per_expert_bytes) per-expert swizzled scales
15+
// D_batched: (num_experts, max_M, N) BF16 output, row-major per expert
16+
//
17+
// No alpha epilogue — tensor scales applied post-hoc in Python (same as
18+
// existing SM_120 grouped kernel).
19+
//
20+
// Must be compiled with: -gencode=arch=compute_120a,code=sm_120a
21+
22+
#include <cstdint>
23+
#include <cuda_bf16.h>
24+
#include <cuda_fp16.h>
25+
#include <cuda_runtime.h>
26+
#include <type_traits>
27+
28+
// ============================================================================
// MMA wrapper: m16n8k64 E2M1 x E2M1 -> F32 with UE4M3 block scales
// ============================================================================
// Issues one block-scaled tensor-core MMA, D = A*B + C, on a 16x8x64 tile
// with E2M1 (NVFP4) operands and UE4M3 scale factors (scale_vec::4X).
//
// Per-lane fragments:
//   a0..a3    packed FP4 A fragment (4 x 32-bit registers)
//   b0,b1     packed FP4 B fragment (2 x 32-bit registers)
//   c0..c3    f32 accumulator inputs
//   d0..d3    f32 results; callers pass the same variables for c* and d*
//             to accumulate in place (see do_compute in the kernel below)
//   sfa,sfb   packed UE4M3 scale words for A and B
//
// The byte-id / thread-id scale-selector operands are fixed to 0, i.e. each
// lane consumes selector group 0 of its scale word — NOTE(review): matches
// the PTX ISA block-scaling description for kind::mxf4nvf4; confirm against
// the PTX ISA if the SF packing ever changes.
//
// Requires sm_120a (compute_120a); see the compile note at the top of file.
__device__ __forceinline__ void mma_nvfp4_m16n8k64(
    float& d0, float& d1, float& d2, float& d3, uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3, uint32_t b0,
    uint32_t b1, float c0, float c1, float c2, float c3, uint32_t sfa, uint32_t sfb
) {
    // Selector operands for the scale metadata; always zero in this kernel.
    uint16_t bidA = 0, tidA = 0, bidB = 0, tidB = 0;
    asm volatile("mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X"
                 ".m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13},"
                 "{%14},"
                 "{%15, %16},"
                 "{%17},"
                 "{%18, %19};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
                 : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(sfa),
                   "h"(bidA), "h"(tidA), "r"(sfb), "h"(bidB), "h"(tidB));
}
50+
51+
// ============================================================================
// Swizzled scale-factor index computation (same layout as the existing
// SM_120 grouped kernel).
//
// Scale factors are stored in 128-row x 4-col blocks of 512 bytes, laid out
// row-major over the (row-block, col-block) grid. Inside a block the 128 rows
// are interleaved: row%32 selects a 16-byte group, row/32 a 4-byte slot, and
// col%4 the byte within the slot. Assumes non-negative row/col.
// ============================================================================
__device__ __forceinline__ int swizzled_scale_offset(int row, int col, int n_col_blocks) {
    // Which 128x4 block this (row, col) falls in, and the position inside it.
    const int block_row = row / 128;
    const int block_col = col / 4;
    const int in_row = row % 128;
    const int in_col = col % 4;
    // Row-major block index, 512 bytes per block.
    const int block_base = (block_row * n_col_blocks + block_col) * 512;
    // Interleaved position within the block (see layout note above).
    return block_base + (in_row % 32) * 16 + (in_row / 32) * 4 + in_col;
}
62+
63+
// ============================================================================
// Output conversion helpers
//
// Narrow an f32 accumulator value to the kernel's output element type.
// Only the three specializations below are provided; instantiating the
// primary template with any other type is a link error by design.
// ============================================================================
template <typename T> __device__ __forceinline__ T float_to_out(float v);

// float output: identity.
template <> __device__ __forceinline__ float float_to_out<float>(float v) {
    return v;
}

// bfloat16 output: round via the CUDA conversion intrinsic.
template <> __device__ __forceinline__ __nv_bfloat16 float_to_out<__nv_bfloat16>(float v) {
    return __float2bfloat16(v);
}

// fp16 output: round via the CUDA conversion intrinsic.
template <> __device__ __forceinline__ half float_to_out<half>(float v) {
    return __float2half(v);
}
72+
73+
// ============================================================================
// Block tile dimensions (same as existing kernel)
//
// Each thread block runs 8 warps arranged as a 2x4 (M x N) warp grid. Every
// warp owns a 16-row strip (one m16n8k64 MMA row tile) and 4 n-tiles of
// 8 columns each, giving a 32-row x 128-column output tile per block.
// ============================================================================
#define MOE_N_TILES_PER_WARP 4
#define MOE_M_WARPS 2
#define MOE_N_WARPS 4
#define MOE_WARPS_PER_BLOCK (MOE_M_WARPS * MOE_N_WARPS) // 8 warps = 256 threads
#define MOE_BLOCK_M_DIM (MOE_M_WARPS * 16) // 32 output rows per block
#define MOE_BLOCK_N_DIM (MOE_N_WARPS * MOE_N_TILES_PER_WARP * 8) // 128 output cols per block
// Shared-memory staging for one 64-element K-step: each A row / B column
// contributes 32 packed-FP4 bytes (64 elements / 2 per byte) plus a 4-byte
// scale word. Total static shared memory: 1024+4096+128+512 = 5760 bytes.
#define MOE_SMEM_A_BYTES (MOE_BLOCK_M_DIM * 32) // 1024
#define MOE_SMEM_B_BYTES (MOE_BLOCK_N_DIM * 32) // 4096
#define MOE_SMEM_SFA_BYTES (MOE_BLOCK_M_DIM * 4) // 128
#define MOE_SMEM_SFB_BYTES (MOE_BLOCK_N_DIM * 4) // 512
#define MOE_SMEM_TOTAL (MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES + MOE_SMEM_SFA_BYTES + MOE_SMEM_SFB_BYTES)
87+
88+
// ============================================================================
// Batched MoE GEMM kernel
//
// Grid: (num_n_tiles, num_m_tiles, num_experts)
//   blockIdx.x = n_tile index (128-column output tile)
//   blockIdx.y = m_tile index (32-row output tile)
//   blockIdx.z = expert index
// Block: MOE_WARPS_PER_BLOCK * 32 = 256 threads. Shared memory:
// MOE_SMEM_TOTAL bytes, statically allocated.
//
// Each expert computes max_M × N output with its own activation/weight data.
// Padded rows (beyond actual token count) produce garbage that the caller
// discards during the gather step.
//
// NOTE(review): the vectorized global loads below (uint32_t for A/SF, uint4
// for B) require each row base (row * K/2) to be aligned, i.e. K a multiple
// of 8 for A and 32 for B; presumably callers always pass K as a multiple of
// 64 (one full K-step) — confirm against the Python-side packing.
// ============================================================================
template <typename OutT>
__global__ __launch_bounds__(MOE_WARPS_PER_BLOCK * 32, 4) void kBatchedMoeGemmNVFP4(
    const unsigned char* __restrict__ A_batched, // (num_experts, max_M, K/2) packed FP4
    const unsigned char* __restrict__ B_all, // (num_experts, N, K/2) packed FP4
    const unsigned char* __restrict__ SFA_batched, // per-expert swizzled act scales
    const unsigned char* __restrict__ SFB_all, // per-expert swizzled wt scales
    OutT* __restrict__ D_batched, // (num_experts, max_M, N)
    int max_M, int N, int K,
    int sfa_per_expert_bytes, // size of one expert's SFA block
    int sfb_per_expert_bytes // size of one expert's SFB block
) {
    // Fixed grid mapping — no dynamic routing, so this is CUDA-graph safe.
    const int expert = blockIdx.z;
    const int n_tile = blockIdx.x;
    const int m_tile = blockIdx.y;

    const int half_K = K / 2;   // bytes per row of packed FP4 (2 elems/byte)
    const int scale_K = K / 16; // one UE4M3 scale per 16 elements
    const int scale_n_col_blocks = (scale_K + 3) / 4; // swizzle blocks along K

    // Point to this expert's data in the batched layout.
    const unsigned char* A = A_batched + (size_t)expert * max_M * half_K;
    const unsigned char* B = B_all + (size_t)expert * N * half_K;
    const unsigned char* SFA = SFA_batched + (size_t)expert * sfa_per_expert_bytes;
    const unsigned char* SFB = SFB_all + (size_t)expert * sfb_per_expert_bytes;
    OutT* D = D_batched + (size_t)expert * max_M * N;
    const int M = max_M; // compute all rows including padding

    // --- Standard tile GEMM (same core logic as kGemmNVFP4_smem) ---
    // One K-step of A, B, and both scale arrays staged in shared memory.
    __shared__ __align__(16) unsigned char smem[MOE_SMEM_TOTAL];
    unsigned char* smem_A = smem;
    unsigned char* smem_B = smem + MOE_SMEM_A_BYTES;
    unsigned char* smem_SFA = smem + MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES;
    unsigned char* smem_SFB = smem + MOE_SMEM_A_BYTES + MOE_SMEM_B_BYTES + MOE_SMEM_SFA_BYTES;

    // Warp/lane decomposition: 2x4 (m_warp x n_warp) warp grid per block.
    const int tid = threadIdx.x;
    const int warp_in_block = tid / 32;
    const int lane_id = tid % 32;
    const int m_warp = warp_in_block / MOE_N_WARPS;
    const int n_warp = warp_in_block % MOE_N_WARPS;

    // Global-coordinate origins of this block's tile and this warp's strip.
    const int block_m = m_tile * MOE_BLOCK_M_DIM;
    const int block_n = n_tile * MOE_BLOCK_N_DIM;
    const int tile_m = block_m + m_warp * 16;
    const int warp_n_base = block_n + n_warp * MOE_N_TILES_PER_WARP * 8;

    // Lane position within the MMA fragment: t0 = lane%4, t1 = lane/4.
    const int t0 = lane_id % 4;
    const int t1 = lane_id / 4;

    // Per-lane f32 accumulators: 4 values for each of the 4 n-tiles.
    float acc[MOE_N_TILES_PER_WARP][4];
#pragma unroll
    for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
        acc[nt][0] = acc[nt][1] = acc[nt][2] = acc[nt][3] = 0.0f;
    }

    // Shared-memory rows of A this lane reads for its MMA A-fragment.
    const int a_local_row0 = m_warp * 16 + 2 * t1;
    const int a_local_row1 = a_local_row0 + 1;
    // Lane -> scale-factor row mapping. NOTE(review): this mirrors the
    // CUTLASS/cute SF atom ordering implied by swizzled_scale_offset (the
    // "cute" in the name suggests so) — verify against the producer side.
    const int sf_tidx = (lane_id % 2) * 8 + (lane_id / 4);
    const int cute_sf_m0 = sf_tidx % 16;
    const int sfa_local_row = m_warp * 16 + (cute_sf_m0 % 8) * 2 + cute_sf_m0 / 8;

    // Global-load assignment: each of the 256 threads fetches 4 contiguous
    // A bytes (32 rows x 32 bytes) and 16 contiguous B bytes (128 x 32) per
    // K-step; rows hold 32 bytes = 64 packed FP4 elements.
    const int a_off = tid * 4;
    const int a_load_row = a_off >> 5;
    const int a_load_col = a_off & 31;
    const int a_gm = block_m + a_load_row;

    const int b_off = tid * 16;
    const int b_load_row = b_off >> 5;
    const int b_load_col = b_off & 31;
    const int b_gn = block_n + b_load_row;

    // Tail guards for the last partial tile in M / N.
    const bool a_gm_ok = (a_gm < M);
    const bool b_gn_ok = (b_gn < N);
    const int a_row_base = a_gm * half_K;
    const int b_row_base = b_gn * half_K;

    // Pipeline registers: next K-step is prefetched here while the current
    // one (already in shared memory) is being consumed.
    uint32_t pipe_a = 0;
    uint4 pipe_b = make_uint4(0, 0, 0, 0);
    uint32_t pipe_sfa = 0, pipe_sfb = 0;

    // Load helper — uses per-expert local offsets for scales. Out-of-range
    // lanes/rows load zeros; K-tails fall back to byte-wise loads.
    auto do_load = [&](int k_byte, int k_scale) {
        pipe_a = 0;
        if (a_gm_ok) {
            int ga = a_row_base + k_byte + a_load_col;
            if (k_byte + a_load_col + 3 < half_K)
                pipe_a = *(const uint32_t*)(A + ga);
            else
                // K-tail: assemble the word byte-by-byte, zero-padding.
                for (int i = 0; i < 4; i++)
                    if (k_byte + a_load_col + i < half_K)
                        pipe_a |= ((uint32_t)A[ga + i]) << (i * 8);
        }
        if (b_gn_ok) {
            int gb = b_row_base + k_byte + b_load_col;
            if (k_byte + b_load_col + 15 < half_K) {
                uint4 bv = *(const uint4*)(B + gb);
                pipe_b.x = bv.x; pipe_b.y = bv.y; pipe_b.z = bv.z; pipe_b.w = bv.w;
            } else {
                // K-tail: stage into a zeroed local buffer first.
                unsigned char buf[16] = {};
                for (int i = 0; i < 16; i++)
                    if (k_byte + b_load_col + i < half_K) buf[i] = B[gb + i];
                pipe_b = *(uint4*)buf;
            }
        } else { pipe_b = make_uint4(0, 0, 0, 0); }

        // SFA: per-expert swizzled layout (local row indices within this
        // expert); first MOE_BLOCK_M_DIM threads each fetch one row's word.
        pipe_sfa = 0;
        if (tid < MOE_BLOCK_M_DIM) {
            int gm = block_m + tid;
            if (gm < M) {
                int bs = swizzled_scale_offset(gm, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfa = *(const uint32_t*)(SFA + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfa |= ((uint32_t)SFA[bs + i]) << (i * 8);
            }
        }
        // SFB: same scheme for the first MOE_BLOCK_N_DIM threads / columns.
        pipe_sfb = 0;
        if (tid < MOE_BLOCK_N_DIM) {
            int gn = block_n + tid;
            if (gn < N) {
                int bs = swizzled_scale_offset(gn, k_scale, scale_n_col_blocks);
                if (k_scale + 3 < scale_K)
                    pipe_sfb = *(const uint32_t*)(SFB + bs);
                else
                    for (int i = 0; i < 4; i++)
                        if (k_scale + i < scale_K)
                            pipe_sfb |= ((uint32_t)SFB[bs + i]) << (i * 8);
            }
        }
    };

    // Publish the prefetched K-step to shared memory. Callers must place a
    // __syncthreads() between do_store() and any do_compute() that reads it.
    auto do_store = [&]() {
        *(uint32_t*)(smem_A + a_off) = pipe_a;
        *(uint4*)(smem_B + b_off) = pipe_b;
        if (tid < MOE_BLOCK_M_DIM) *(uint32_t*)(smem_SFA + tid * 4) = pipe_sfa;
        if (tid < MOE_BLOCK_N_DIM) *(uint32_t*)(smem_SFB + tid * 4) = pipe_sfb;
    };

    // Consume the K-step currently in shared memory: each warp issues one
    // block-scaled MMA per n-tile, accumulating into acc in place.
    auto do_compute = [&]() {
        uint32_t ar[4];
        // A fragment: two rows x two 16-byte halves of the 32-byte K-step.
        ar[0] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4);
        ar[1] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4);
        ar[2] = *(const uint32_t*)(smem_A + a_local_row0 * 32 + t0 * 4 + 16);
        ar[3] = *(const uint32_t*)(smem_A + a_local_row1 * 32 + t0 * 4 + 16);
        uint32_t sf = *(const uint32_t*)(smem_SFA + sfa_local_row * 4);
#pragma unroll
        for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
            int ln = n_warp * MOE_N_TILES_PER_WARP * 8 + nt * 8; // local col of this n-tile
            int br = ln + t1; // B row (column of the output) this lane reads
            uint32_t b0 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4);
            uint32_t b1 = *(const uint32_t*)(smem_B + br * 32 + t0 * 4 + 16);
            uint32_t sb = *(const uint32_t*)(smem_SFB + (ln + t1) * 4);
            mma_nvfp4_m16n8k64(
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3],
                ar[0], ar[1], ar[2], ar[3], b0, b1,
                acc[nt][0], acc[nt][1], acc[nt][2], acc[nt][3], sf, sb
            );
        }
    };

    // Prologue: load and publish the first K-step.
    do_load(0, 0);
    do_store();
    __syncthreads();

    // Software pipeline over K in 64-element steps: prefetch step k+1 into
    // registers while computing step k from shared memory. The first barrier
    // ensures all warps finished reading smem before it is overwritten; the
    // second publishes the new step before the next compute.
    for (int k_start = 0; k_start < K; k_start += 64) {
        bool has_next = (k_start + 64 < K);
        if (has_next) do_load((k_start + 64) / 2, (k_start + 64) / 16);
        do_compute();
        __syncthreads();
        if (has_next) { do_store(); __syncthreads(); }
    }

    // Write output (no split-K, direct store). Fragment -> (row, col):
    // lane/4 picks a 2-row pair within the warp's 16-row strip, lane%4 picks
    // a 2-column pair within each 8-column n-tile (m16n8 accumulator layout).
    int octet = lane_id / 4;
    int quad = lane_id % 4;
    int out_row0 = tile_m + octet * 2;
    int out_row1 = out_row0 + 1;
    int out_col_base = quad * 2;

#pragma unroll
    for (int nt = 0; nt < MOE_N_TILES_PER_WARP; nt++) {
        int this_tile_n = warp_n_base + nt * 8;
        int c0 = this_tile_n + out_col_base;
        int c1 = c0 + 1;
        // Bounds-checked stores handle the partial tiles at the M/N edges.
        if (out_row0 < M && c0 < N) D[out_row0 * N + c0] = float_to_out<OutT>(acc[nt][0]);
        if (out_row0 < M && c1 < N) D[out_row0 * N + c1] = float_to_out<OutT>(acc[nt][1]);
        if (out_row1 < M && c0 < N) D[out_row1 * N + c0] = float_to_out<OutT>(acc[nt][2]);
        if (out_row1 < M && c1 < N) D[out_row1 * N + c1] = float_to_out<OutT>(acc[nt][3]);
    }
}
295+
296+
// ============================================================================
297+
// Launcher and C interface
298+
// ============================================================================
299+
300+
// Host-side launcher: sizes the tile grid for one batched MoE GEMM and
// enqueues kBatchedMoeGemmNVFP4 on `stream` (asynchronous; no sync here).
// Grid Z covers experts; X/Y cover the N/M tile space of one expert's
// max_M x N output. Output type is selected by the OutT instantiation.
template <typename OutT>
static void launch_batched_moe_gemm_nvfp4(
    const unsigned char* A_batched, const unsigned char* B_all,
    const unsigned char* SFA_batched, const unsigned char* SFB_all,
    OutT* D_batched,
    int max_M, int N, int K, int num_experts,
    int sfa_per_expert_bytes, int sfb_per_expert_bytes,
    cudaStream_t stream
) {
    // Ceil-divide the per-expert output into MOE_BLOCK_M_DIM x MOE_BLOCK_N_DIM
    // tiles; one thread block per tile per expert.
    const int m_tiles = (max_M + MOE_BLOCK_M_DIM - 1) / MOE_BLOCK_M_DIM;
    const int n_tiles = (N + MOE_BLOCK_N_DIM - 1) / MOE_BLOCK_N_DIM;
    const dim3 grid(n_tiles, m_tiles, num_experts);
    const dim3 block(MOE_WARPS_PER_BLOCK * 32);

    kBatchedMoeGemmNVFP4<OutT><<<grid, block, 0, stream>>>(
        A_batched, B_all, SFA_batched, SFB_all, D_batched,
        max_M, N, K, sfa_per_expert_bytes, sfb_per_expert_bytes
    );
}
319+
320+
// C-linkage entry point for the BF16-output batched MoE GEMM.
// Thin forwarder to the __nv_bfloat16 instantiation of the templated
// launcher, kept extern "C" so the symbol has an unmangled name for
// callers outside C++. Asynchronous with respect to the host: work is
// enqueued on `stream` and no synchronization is performed here.
// See the kernel's header comment for the expected data layouts and the
// (hedged) K-alignment precondition.
extern "C" void cgemm_nvfp4_moe_bf16(
    const unsigned char* A_batched,
    const unsigned char* B_all,
    const unsigned char* SFA_batched,
    const unsigned char* SFB_all,
    __nv_bfloat16* D_batched,
    int max_M, int N, int K, int num_experts,
    int sfa_per_expert_bytes, int sfb_per_expert_bytes,
    cudaStream_t stream
) {
    launch_batched_moe_gemm_nvfp4<__nv_bfloat16>(
        A_batched, B_all, SFA_batched, SFB_all, D_batched,
        max_M, N, K, num_experts,
        sfa_per_expert_bytes, sfb_per_expert_bytes, stream
    );
}

0 commit comments

Comments
 (0)