Skip to content

Commit 1cfc620

Browse files
TimDettmers and claude committed
fix: Correct CuTE thread decomposition in NVFP4 GEMM kernel
In CuTE layouts, Shape<_4,_8> means the first mode is fastest: T = t0 + t1*4, so t0 = T%4 and t1 = T/4. The kernel used the inverse decomposition (t0 = T/8, t1 = T%8), which placed data in the wrong register positions for the MMA instruction.

Fixed all four layout mappings:
- ALayout: t0=lane%4, t1=lane/4 (was lane/8, lane%8)
- BLayout: same correction
- SFALayout: sf_idx=(lane%2)*8+(lane/4) (was (lane/16)*8+(lane%8))
- SFBLayout: sf_idx=lane/4 (was lane%8)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5e4df35 commit 1cfc620

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -162,9 +162,10 @@ __global__ void kGemmNVFP4_simple(
162162
// Accumulator registers
163163
float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
164164

165-
// Thread layout decomposition
166-
int t0 = lane_id / 8; // 0-3
167-
int t1 = lane_id % 8; // 0-7
165+
// CuTE thread decomposition: Shape<_4,_8> means first mode is fastest
166+
// T = t0 + t1*4, so t0 = T%4 (0-3), t1 = T/4 (0-7)
167+
int t0 = lane_id % 4; // 0-3
168+
int t1 = lane_id / 4; // 0-7
168169

169170
// Iterate over K dimension in steps of 64
170171
for (int k_start = 0; k_start < K; k_start += 64) {
@@ -241,18 +242,13 @@ __global__ void kGemmNVFP4_simple(
241242

242243
// Load SFA: 1 x uint32 (4 packed UE4M3 bytes)
243244
// SFALayout: Shape<Shape<_2,_2,_8>,_64>, Stride<Stride<_8,_0,_1>,_16>
244-
// For thread t: t_decomp = (t0_sf=t/16, t1_sf=(t/8)%2, t2_sf=t%8)
245-
// t0_sf = lane_id / 16 (0-1)
246-
// t1_sf = (lane_id / 8) % 2 (0-1, but stride=0 so broadcast)
247-
// t2_sf = lane_id % 8 (0-7)
248-
// Thread index into SF = t0_sf*8 + t2_sf = lane_id/16*8 + lane_id%8
249-
// Value dimension: 4 values (4 scale factors), stride 16
250-
// SF element = thread_idx + value_idx * 16
251-
// With M16xK64: SF has 16 rows, 4 cols (K/16=4)
252-
// thread_idx maps to the M dimension, value_idx to K/16 dimension
245+
// CuTE: T = t0 + t1*2 + t2*4, so t0=T%2, t1=(T/2)%2, t2=T/4
246+
// Strides: (8, 0, 1). t1 has stride 0 (broadcast).
247+
// sf_thread_contrib = t0*8 + t2 = (lane%2)*8 + (lane/4)
248+
// SF coord = sf_thread_contrib + v*16 (column-major: m=coord%16, k_blk=coord/16)
253249
uint32_t sfa_packed = 0;
254250
{
255-
int sf_thread_idx = (lane_id / 16) * 8 + (lane_id % 8);
251+
int sf_thread_idx = (lane_id % 2) * 8 + (lane_id / 4);
256252
for (int sf_v = 0; sf_v < 4; sf_v++) {
257253
int sf_element = sf_thread_idx + sf_v * 16;
258254
int sf_row = sf_element % 16; // M index in tile
@@ -271,14 +267,12 @@ __global__ void kGemmNVFP4_simple(
271267

272268
// Load SFB: 1 x uint32 (4 packed UE4M3 bytes)
273269
// SFBLayout: Shape<Shape<_4,_8>,_64>, Stride<Stride<_0,_1>,_8>
274-
// t0_sfb = lane_id / 8 (0-3, but stride=0 so broadcast)
275-
// t1_sfb = lane_id % 8 (0-7)
276-
// Thread idx = t1_sfb = lane_id % 8
277-
// SF element = thread_idx + value_idx * 8
278-
// With N8xK64: SF has 8 rows, 4 cols (K/16=4)
270+
// CuTE: T = t0 + t1*4, so t0=T%4 (stride=0, broadcast), t1=T/4
271+
// sf_thread_contrib = t1 = lane/4
272+
// SF coord = sf_thread_contrib + v*8 (column-major: n=coord%8, k_blk=coord/8)
279273
uint32_t sfb_packed = 0;
280274
{
281-
int sf_thread_idx = lane_id % 8;
275+
int sf_thread_idx = lane_id / 4;
282276
for (int sf_v = 0; sf_v < 4; sf_v++) {
283277
int sf_element = sf_thread_idx + sf_v * 8;
284278
int sf_row = sf_element % 8; // N index in tile

0 commit comments

Comments
 (0)