Skip to content

Commit bfe0916

Browse files
TimDettmersclaude
andcommitted
fix: Remap CuTE M-index from interleaved to sequential row order
The CuTE ALayout for m16n8k64 MMA uses column-major indexing which interleaves rows: [0,8], [1,9], [2,10], ... But the SM80_16x8_Row output layout expects consecutive row pairs: [0,1], [2,3], ... Diagnostic showed: A row 0 → D[0], A row 8 → D[1], A row 1 → D[2], which means the MMA maps CuTE m-indices 0,8,1,9,... to output rows 0,1,2,3,... Fix: Remap when loading A data and SFA scales: actual_m = (cute_m % 8) * 2 + cute_m / 8 This ensures A row i goes to output row i. Applied to both A data loading and SFA scale loading. B and SFB are unaffected (no interleaving issue for N-dimension). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e749d15 commit bfe0916

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ __global__ void kGemmNVFP4_simple(
173173
// ALayout: coord = t0*128 + t1 + v0*16 + v1*8 + v2*512
174174
// CuTE coord space is column-major in tile: m = coord%16, k = coord/16
175175
// Value decomposition: v = v0 + v1*8 + v2*16 (v0=0..7, v1=0..1, v2=0..1)
176+
//
177+
// CRITICAL: CuTE column-major M-index interleaves rows [0,8], [1,9], ...
178+
// but the SM80_16x8 output layout expects consecutive row pairs [0,1], [2,3], ...
179+
// We remap: actual_m = (cute_m % 8) * 2 + cute_m / 8
180+
// so CuTE m=0 → actual 0, m=8 → actual 1, m=1 → actual 2, m=9 → actual 3, etc.
176181
uint32_t a_regs[4];
177182
for (int reg = 0; reg < 4; reg++) {
178183
uint32_t packed = 0;
@@ -183,8 +188,10 @@ __global__ void kGemmNVFP4_simple(
183188
int v2 = v / 16;
184189

185190
int coord = t0 * 128 + t1 + v0 * 16 + v1 * 8 + v2 * 512;
186-
int tile_row = coord % 16; // M index within tile (column-major)
191+
int cute_m = coord % 16; // CuTE M index (interleaved)
187192
int tile_col = coord / 16; // K index within tile
193+
// Remap from CuTE interleaved to sequential row order
194+
int tile_row = (cute_m % 8) * 2 + cute_m / 8;
188195

189196
int global_m = tile_m + tile_row;
190197
int global_k = k_start + tile_col;
@@ -246,8 +253,10 @@ __global__ void kGemmNVFP4_simple(
246253
int sf_thread_idx = (lane_id % 2) * 8 + (lane_id / 4);
247254
for (int sf_v = 0; sf_v < 4; sf_v++) {
248255
int sf_element = sf_thread_idx + sf_v * 16;
249-
int sf_row = sf_element % 16; // M index in tile
250-
int sf_col = sf_element / 16; // K/16 index in tile
256+
int cute_sf_m = sf_element % 16; // CuTE M index (interleaved)
257+
int sf_col = sf_element / 16; // K/16 index in tile
258+
// Same remapping as A data: CuTE interleaved → sequential
259+
int sf_row = (cute_sf_m % 8) * 2 + cute_sf_m / 8;
251260

252261
int global_m = tile_m + sf_row;
253262
int global_k_block = k_start / 16 + sf_col;

0 commit comments

Comments
 (0)