Skip to content

Commit e749d15

Browse files
TimDettmers and claude committed
fix: Use column-major coord-to-tile mapping for MMA data registers
CuTe layout coordinates for m16n8k64 MMA tiles are column-major:
- A tile (16x64): m = coord % 16, k = coord / 16
- B tile (8x64): n = coord % 8, k = coord / 8
The previous code used a row-major mapping (row = coord / 64, col = coord % 64, i.e. m = coord / 64 for A), which placed data in the wrong register positions and produced incorrect results for non-uniform data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1cfc620 commit e749d15

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -170,30 +170,24 @@ __global__ void kGemmNVFP4_simple(
170170
// Iterate over K dimension in steps of 64
171171
for (int k_start = 0; k_start < K; k_start += 64) {
172172
// Load A registers: 4 x uint32 (32 E2M1 values per thread)
173-
// Using ALayout: element_idx = t0*128 + t1 + v0*16 + v1*8 + v2*512
174-
// where v = v2*16 + v1*8 + v0 (v0=0..7, v1=0..1, v2=0..1)
173+
// ALayout: coord = t0*128 + t1 + v0*16 + v1*8 + v2*512
174+
// CuTE coord space is column-major in tile: m = coord%16, k = coord/16
175+
// Value decomposition: v = v0 + v1*8 + v2*16 (v0=0..7, v1=0..1, v2=0..1)
175176
uint32_t a_regs[4];
176177
for (int reg = 0; reg < 4; reg++) {
177178
uint32_t packed = 0;
178179
for (int nib = 0; nib < 8; nib++) {
179-
// v = reg * 8 + nib (value index 0..31)
180-
int v0 = nib; // 0..7
181-
int v1 = (reg / 1) % 2; // reg 0,1 -> v1=0; wait need to recompute
182-
int v2 = reg / 2; // reg 0,1 -> v2=0; reg 2,3 -> v2=1
183-
184-
// Actually reg maps to: reg0 = v[0..7], reg1 = v[8..15], etc.
185-
// v = reg*8 + nib
186180
int v = reg * 8 + nib;
187-
v0 = v % 8;
188-
v1 = (v / 8) % 2;
189-
v2 = v / 16;
181+
int v0 = v % 8;
182+
int v1 = (v / 8) % 2;
183+
int v2 = v / 16;
190184

191-
int element_idx = t0 * 128 + t1 * 1 + v0 * 16 + v1 * 8 + v2 * 512;
192-
int row = element_idx / 64; // M index within tile
193-
int col = element_idx % 64; // K index within tile
185+
int coord = t0 * 128 + t1 + v0 * 16 + v1 * 8 + v2 * 512;
186+
int tile_row = coord % 16; // M index within tile (column-major)
187+
int tile_col = coord / 16; // K index within tile
194188

195-
int global_m = tile_m + row;
196-
int global_k = k_start + col;
189+
int global_m = tile_m + tile_row;
190+
int global_k = k_start + tile_col;
197191

198192
uint32_t nibble = 0;
199193
if (global_m < M && global_k < K) {
@@ -210,7 +204,8 @@ __global__ void kGemmNVFP4_simple(
210204
}
211205

212206
// Load B registers: 2 x uint32 (16 E2M1 values per thread)
213-
// BLayout: element_idx = t0*64 + t1 + v0*8 + v1*256
207+
// BLayout: coord = t0*64 + t1 + v0*8 + v1*256
208+
// CuTE coord space is column-major: n = coord%8, k = coord/8
214209
uint32_t b_regs[2];
215210
for (int reg = 0; reg < 2; reg++) {
216211
uint32_t packed = 0;
@@ -219,12 +214,12 @@ __global__ void kGemmNVFP4_simple(
219214
int v0 = v % 8;
220215
int v1 = v / 8;
221216

222-
int element_idx = t0 * 64 + t1 * 1 + v0 * 8 + v1 * 256;
223-
int row = element_idx / 64; // N index within tile
224-
int col = element_idx % 64; // K index within tile
217+
int coord = t0 * 64 + t1 + v0 * 8 + v1 * 256;
218+
int tile_row = coord % 8; // N index within tile (column-major)
219+
int tile_col = coord / 8; // K index within tile
225220

226-
int global_n = tile_n + row;
227-
int global_k = k_start + col;
221+
int global_n = tile_n + tile_row;
222+
int global_k = k_start + tile_col;
228223

229224
uint32_t nibble = 0;
230225
if (global_n < N && global_k < K) {

0 commit comments

Comments
 (0)