Skip to content

Commit 45bdcce

Browse files
TimDettmers and claude
committed
feat: Add simple NVFP4 GEMM kernel for SM_120a (correctness-first)
Initial GEMM implementation using mma.sync.aligned.block_scale PTX instruction. One warp per m16n8 output tile, iterating over K in steps of 64. Direct global memory loads (no shared memory staging). Includes CMakeLists.txt changes to compile with compute_120a target when SM_120 is in the target architectures. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1afdde6 commit 45bdcce

File tree

2 files changed

+368
-0
lines changed

2 files changed

+368
-0
lines changed

CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,24 @@ if(BUILD_CUDA)
227227

228228
list(APPEND SRC_FILES ${CUDA_FILES})
229229

230+
# The NVFP4 block-scaled MMA instruction requires the arch-specific
# compute_120a target; only build the kernel when a 120/121 capability
# was requested.
set(_HAS_SM120 FALSE)
if("120" IN_LIST COMPUTE_CAPABILITY OR "121" IN_LIST COMPUTE_CAPABILITY)
    set(_HAS_SM120 TRUE)
endif()
if(_HAS_SM120)
    set(SM120A_FILE csrc/kernels_nvfp4_sm120.cu)
    list(APPEND SRC_FILES ${SM120A_FILE})
    # Force the raw -gencode flag for this one file; CUDA_ARCHITECTURES "OFF"
    # stops CMake from appending its own (conflicting) arch flags.
    set_source_files_properties(${SM120A_FILE} PROPERTIES
        COMPILE_FLAGS "-gencode=arch=compute_120a,code=sm_120a"
        CUDA_ARCHITECTURES "OFF"
    )
    message(STATUS "NVFP4 SM_120a GEMM kernel enabled")
endif()
247+
230248
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
231249
add_compile_definitions(BUILD_CUDA)
232250
elseif(BUILD_HIP)

csrc/kernels_nvfp4_sm120.cu

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
// NVFP4 Block-Scaled GEMM Kernel for SM_120a (Blackwell Consumer GPUs)
2+
// Uses: mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X
3+
// .m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3
4+
//
5+
// Must be compiled with: -gencode=arch=compute_120a,code=sm_120a
6+
//
7+
// Computes: D = A * B (NVFP4 inputs with block scales, BF16 output)
8+
// A: M x K (row-major packed FP4, 2 values per byte)
9+
// B: K x N (column-major packed FP4, 2 values per byte)
10+
// SFA: M x (K/16) UE4M3 block scales for A
11+
// SFB: N x (K/16) UE4M3 block scales for B
12+
// D: M x N BF16 output (first version: BF16 output, not NVFP4)
13+
14+
#include <cstdint>
15+
#include <cuda_bf16.h>
16+
#include <cuda_fp16.h>
17+
#include <cuda_runtime.h>
18+
19+
// ============================================================================
// Warp-level MMA: D = A * B + C on one m16n8k64 tile.
//   a0..a3 : A fragment, packed E2M1 (FP4), 8 nibbles per uint32 register
//   b0, b1 : B fragment, packed E2M1
//   c0..c3 / d0..d3 : F32 accumulator in / out fragments (4 per thread)
//   sfa, sfb : 4 packed UE4M3 block scales each (scale_vec::4X)
// Requires SM_120a: compile with -gencode=arch=compute_120a,code=sm_120a.
// ============================================================================
__device__ __forceinline__ void mma_nvfp4_m16n8k64(
    float &d0, float &d1, float &d2, float &d3,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    float c0, float c1, float c2, float c3,
    uint32_t sfa, uint32_t sfb
) {
    // The scale-selector operands (byte-id / thread-id for A and B) are
    // hard-wired to 0 here.
    const uint16_t sel_zero = 0;
    asm volatile(
        "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X"
        ".m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13},"
        "{%14},"
        "{%15, %16},"
        "{%17},"
        "{%18, %19};\n"
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
        : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
          "r"(b0), "r"(b1),
          "f"(c0), "f"(c1), "f"(c2), "f"(c3),
          "r"(sfa), "h"(sel_zero), "h"(sel_zero),
          "r"(sfb), "h"(sel_zero), "h"(sel_zero)
    );
}
49+
50+
// ============================================================================
51+
// Simple NVFP4 GEMM kernel (correctness-first, not performance-optimized)
52+
//
53+
// This kernel is designed for correctness verification first.
54+
// Each warp computes one m16n8 output tile, iterating over K.
55+
//
56+
// Layout assumptions:
57+
// A: M x K, row-major, packed FP4 (2 per byte). Byte [i * K/2 + k/2]
58+
// B: N x K, "column-major" meaning B is stored as N rows of K (B^T in memory).
59+
// Packed FP4. Byte [j * K/2 + k/2]. This matches TN layout for MMA.
60+
// SFA: M x (K/16), row-major UE4M3. Byte [i * (K/16) + k/16]
61+
// SFB: N x (K/16), row-major UE4M3. Byte [j * (K/16) + k/16]
62+
//
63+
// MMA register mapping (SM80_16x8_Row for C/D):
64+
// Thread tid (0-31), octet = tid/4, quad = tid%4
65+
// d[0] = C[octet*2, quad*2]
66+
// d[1] = C[octet*2, quad*2+1]
67+
// d[2] = C[octet*2+1, quad*2]
68+
// d[3] = C[octet*2+1, quad*2+1]
69+
//
70+
// A register mapping (from CUTLASS ALayout for m16n8k64):
71+
// Thread tid, 4 regs of 8 nibbles each = 32 values per thread
72+
// The layout is complex; we use ldmatrix or manual packing.
73+
//
74+
// Note: ldmatrix does not support FP4 fragments, so we cannot stage tiles
// in shared memory and load them with ldmatrix.x4. Instead we must
// understand the exact MMA register layout and pack the data manually.
81+
//
82+
// MMA A register layout for m16n8k64 (from CUTLASS):
83+
// ALayout = Layout<Shape<Shape<_4,_8>, Shape<_8,_2,_2>>,
84+
// Stride<Stride<_128,_1>, Stride<_16,_8,_512>>>
85+
// This maps (T32, V32) -> element index in M16xK64 tile (row-major)
86+
//
87+
// For thread t, value v:
88+
// t0 = t/8, t1 = t%8 (thread decomposition)
89+
// v0 = v%8, v1 = (v/8)%2, v2 = v/16 (value decomposition)
90+
// element_idx = t0*128 + t1*1 + v0*16 + v1*8 + v2*512
91+
// row = element_idx / 64 (M dimension)
92+
// col = element_idx % 64 (K dimension)
93+
//
94+
// Since values are packed 8 per uint32 register:
95+
// reg[0] = values v=0..7, reg[1] = v=8..15, reg[2] = v=16..23, reg[3] = v=24..31
96+
//
97+
// MMA B register layout for m16n8k64 (from CUTLASS):
98+
// BLayout = Layout<Shape<Shape<_4,_8>, Shape<_8,_2>>,
99+
// Stride<Stride<_64,_1>, Stride<_8,_256>>>
100+
// For thread t, value v:
101+
// t0 = t/8, t1 = t%8
102+
// v0 = v%8, v1 = v/8
103+
// element_idx = t0*64 + t1*1 + v0*8 + v1*256
104+
// row = element_idx / 64 (N dimension)
105+
// col = element_idx % 64 (K dimension)
106+
//
107+
// SFA register layout:
108+
// SFALayout = Layout<Shape<Shape<_2,_2,_8>,_64>,
109+
// Stride<Stride<_8,_0,_1>,_16>>
110+
// (T32,V64) -> (M16, K64) scale factor index
111+
// The _0 stride means dimension 1 is broadcast
112+
// For thread t: t0 = t/16, t1 = (t/8)%2, t2 = t%8
113+
// Scale idx = t0*8 + t2*1 + v*16 where v=0..3 (4 SFs per row)
114+
// But with _0 stride: pairs of threads read same scales
115+
//
116+
// For this first implementation, each thread packs its A/B/SF registers
// directly from global memory inside the kernel, following these layouts.
118+
// ============================================================================
119+
120+
// Helper: pack 8 consecutive FP4 (4-bit) values into a single uint32.
// `data` stores two FP4 values per byte: low nibble = even element index,
// high nibble = odd element index. Element `start_idx + i` ends up in
// bits [4*i, 4*i + 4) of the result.
// NOTE(review): currently unused in this file; kept as a general utility.
__device__ __forceinline__ uint32_t pack_8_nibbles(
    const unsigned char* data, int start_idx
) {
    uint32_t packed = 0;
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        const int elem_idx = start_idx + i;
        const int byte_idx = elem_idx / 2;
        const uint32_t nibble = (elem_idx % 2 == 0)
            ? (data[byte_idx] & 0x0F)
            : ((data[byte_idx] >> 4) & 0x0F);
        packed |= nibble << (i * 4);
    }
    return packed;
}
140+
141+
// ============================================================================
// Simple GEMM kernel: one warp per m16n8 output tile, iterating over K in
// steps of 64. Fragments are packed directly from global memory (no shared
// memory staging) — correctness-first, not performance-optimized.
//
// Preconditions:
//   - K must be even (A/B store 2 FP4 values per byte, row stride K/2).
//   - Scale indexing assumes K is a multiple of 16 (one UE4M3 scale per
//     16 elements, row stride K/16).
// Launch: 1-D grid and block; blockDim.x must be a multiple of 32.
// ============================================================================
__global__ void kGemmNVFP4_simple(
    const unsigned char* __restrict__ A,   // M x K/2 packed FP4 (row-major)
    const unsigned char* __restrict__ B,   // N x K/2 packed FP4 (B transposed, row-major)
    const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales
    const unsigned char* __restrict__ SFB, // N x K/16 UE4M3 scales
    float* __restrict__ D,                 // M x N output (F32)
    int M, int N, int K
) {
    // Warp-level tiling: each warp computes one m16n8 output tile.
    int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    int lane_id = threadIdx.x % 32;

    // Map warp to output tile.
    int num_n_tiles = (N + 7) / 8;
    int tile_m = (warp_id / num_n_tiles) * 16;
    int tile_n = (warp_id % num_n_tiles) * 8;

    if (tile_m >= M || tile_n >= N) return;

    // Accumulator registers (SM80_16x8_Row fragment, 4 floats per thread).
    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;

    // Thread decomposition used by the CUTLASS A/B fragment layouts.
    int t0 = lane_id / 8; // 0-3
    int t1 = lane_id % 8; // 0-7

    // Loop-invariant row strides.
    const int k_bytes  = K / 2;  // packed-FP4 bytes per row of A / B
    const int k_blocks = K / 16; // UE4M3 scales per row of SFA / SFB

    // Iterate over the K dimension in steps of 64.
    for (int k_start = 0; k_start < K; k_start += 64) {
        // --- A fragment: 4 x uint32 = 32 E2M1 values per thread ---
        // ALayout: element_idx = t0*128 + t1 + v0*16 + v1*8 + v2*512
        // with v = reg*8 + nib, v0 = v%8, v1 = (v/8)%2, v2 = v/16.
        // reg0 holds v=0..7, reg1 v=8..15, reg2 v=16..23, reg3 v=24..31.
        uint32_t a_regs[4];
        for (int reg = 0; reg < 4; reg++) {
            uint32_t packed = 0;
            for (int nib = 0; nib < 8; nib++) {
                int v  = reg * 8 + nib; // value index 0..31
                int v0 = v % 8;
                int v1 = (v / 8) % 2;
                int v2 = v / 16;

                int element_idx = t0 * 128 + t1 * 1 + v0 * 16 + v1 * 8 + v2 * 512;
                int row = element_idx / 64; // M index within tile
                int col = element_idx % 64; // K index within tile

                int global_m = tile_m + row;
                int global_k = k_start + col;

                // Out-of-bounds elements contribute a zero nibble.
                uint32_t nibble = 0;
                if (global_m < M && global_k < K) {
                    int byte_idx = global_m * k_bytes + global_k / 2;
                    if (global_k % 2 == 0) {
                        nibble = A[byte_idx] & 0x0F;
                    } else {
                        nibble = (A[byte_idx] >> 4) & 0x0F;
                    }
                }
                packed |= (nibble << (nib * 4));
            }
            a_regs[reg] = packed;
        }

        // --- B fragment: 2 x uint32 = 16 E2M1 values per thread ---
        // BLayout: element_idx = t0*64 + t1 + v0*8 + v1*256.
        uint32_t b_regs[2];
        for (int reg = 0; reg < 2; reg++) {
            uint32_t packed = 0;
            for (int nib = 0; nib < 8; nib++) {
                int v  = reg * 8 + nib;
                int v0 = v % 8;
                int v1 = v / 8;

                int element_idx = t0 * 64 + t1 * 1 + v0 * 8 + v1 * 256;
                int row = element_idx / 64; // N index within tile
                int col = element_idx % 64; // K index within tile

                int global_n = tile_n + row;
                int global_k = k_start + col;

                uint32_t nibble = 0;
                if (global_n < N && global_k < K) {
                    int byte_idx = global_n * k_bytes + global_k / 2;
                    if (global_k % 2 == 0) {
                        nibble = B[byte_idx] & 0x0F;
                    } else {
                        nibble = (B[byte_idx] >> 4) & 0x0F;
                    }
                }
                packed |= (nibble << (nib * 4));
            }
            b_regs[reg] = packed;
        }

        // --- SFA: 1 x uint32 = 4 packed UE4M3 bytes ---
        // SFALayout: Shape<Shape<_2,_2,_8>,_64>, Stride<Stride<_8,_0,_1>,_16>.
        // The _0 stride means pairs of threads read the same scales:
        //   thread index into SF = lane_id/16*8 + lane_id%8
        //   SF element = thread_idx + value_idx*16, value_idx = 0..3
        // thread_idx maps to the M dimension, value_idx to the K/16 dimension.
        uint32_t sfa_packed = 0;
        {
            int sf_thread_idx = (lane_id / 16) * 8 + (lane_id % 8);
            for (int sf_v = 0; sf_v < 4; sf_v++) {
                int sf_element = sf_thread_idx + sf_v * 16;
                int sf_row = sf_element % 16; // M index in tile
                int sf_col = sf_element / 16; // K/16 index in tile

                int global_m = tile_m + sf_row;
                int global_k_block = k_start / 16 + sf_col;

                unsigned char sf_val = 0;
                if (global_m < M && global_k_block < k_blocks) {
                    sf_val = SFA[global_m * k_blocks + global_k_block];
                }
                sfa_packed |= ((uint32_t)sf_val << (sf_v * 8));
            }
        }

        // --- SFB: 1 x uint32 = 4 packed UE4M3 bytes ---
        // SFBLayout: Shape<Shape<_4,_8>,_64>, Stride<Stride<_0,_1>,_8>.
        // Thread index = lane_id % 8 (groups of 4 threads broadcast);
        // SF element = thread_idx + value_idx*8 over an N8 x K/16=4 grid.
        uint32_t sfb_packed = 0;
        {
            int sf_thread_idx = lane_id % 8;
            for (int sf_v = 0; sf_v < 4; sf_v++) {
                int sf_element = sf_thread_idx + sf_v * 8;
                int sf_row = sf_element % 8; // N index in tile
                int sf_col = sf_element / 8; // K/16 index in tile

                int global_n = tile_n + sf_row;
                int global_k_block = k_start / 16 + sf_col;

                unsigned char sf_val = 0;
                if (global_n < N && global_k_block < k_blocks) {
                    sf_val = SFB[global_n * k_blocks + global_k_block];
                }
                sfb_packed |= ((uint32_t)sf_val << (sf_v * 8));
            }
        }

        // Execute the block-scaled MMA, accumulating into acc0..acc3.
        mma_nvfp4_m16n8k64(
            acc0, acc1, acc2, acc3,
            a_regs[0], a_regs[1], a_regs[2], a_regs[3],
            b_regs[0], b_regs[1],
            acc0, acc1, acc2, acc3,
            sfa_packed, sfb_packed
        );
    }

    // Write output using the SM80_16x8_Row accumulator layout:
    //   octet = lane/4, quad = lane%4
    //   d[0] = C[octet*2,   quad*2]    d[1] = C[octet*2,   quad*2+1]
    //   d[2] = C[octet*2+1, quad*2]    d[3] = C[octet*2+1, quad*2+1]
    int octet = lane_id / 4;
    int quad = lane_id % 4;

    int out_row0 = tile_m + octet * 2;
    int out_row1 = tile_m + octet * 2 + 1;
    int out_col0 = tile_n + quad * 2;
    int out_col1 = tile_n + quad * 2 + 1;

    if (out_row0 < M && out_col0 < N) D[out_row0 * N + out_col0] = acc0;
    if (out_row0 < M && out_col1 < N) D[out_row0 * N + out_col1] = acc1;
    if (out_row1 < M && out_col0 < N) D[out_row1 * N + out_col0] = acc2;
    if (out_row1 < M && out_col1 < N) D[out_row1 * N + out_col1] = acc3;
}
327+
328+
// Host-side launcher for kGemmNVFP4_simple.
// Computes D = A * B with NVFP4 (E2M1) inputs, UE4M3 block scales, and F32
// output. All pointers must be device pointers; launches on the default
// stream and does not synchronize.
// Preconditions mirror the kernel's: K even, K % 16 == 0 for scale layout
// — TODO confirm callers enforce this.
extern "C" void cgemm_nvfp4(
    const unsigned char* A,
    const unsigned char* B,
    const unsigned char* SFA,
    const unsigned char* SFB,
    float* D,
    int M, int N, int K
) {
    // Guard degenerate sizes: launching with num_blocks == 0 would be an
    // invalid-configuration CUDA error, and there is nothing to compute.
    if (M <= 0 || N <= 0 || K <= 0) return;

    // Each warp handles one m16n8 output tile.
    int num_m_tiles = (M + 15) / 16;
    int num_n_tiles = (N + 7) / 8;
    int total_warps = num_m_tiles * num_n_tiles;

    // 4 warps per block (128 threads), ceil-divide warps into blocks.
    const int warps_per_block = 4;
    const int threads_per_block = warps_per_block * 32;
    int num_blocks = (total_warps + warps_per_block - 1) / warps_per_block;

    kGemmNVFP4_simple<<<num_blocks, threads_per_block>>>(
        A, B, SFA, SFB, D, M, N, K
    );
}

0 commit comments

Comments
 (0)