@@ -170,30 +170,24 @@ __global__ void kGemmNVFP4_simple(
170170 // Iterate over K dimension in steps of 64
171171 for (int k_start = 0 ; k_start < K; k_start += 64 ) {
172172 // Load A registers: 4 x uint32 (32 E2M1 values per thread)
173- // Using ALayout: element_idx = t0*128 + t1 + v0*16 + v1*8 + v2*512
174- // where v = v2*16 + v1*8 + v0 (v0=0..7, v1=0..1, v2=0..1)
173+ // ALayout: coord = t0*128 + t1 + v0*16 + v1*8 + v2*512
174+ // CuTe coord space is column-major within the tile: m = coord%16, k = coord/16
175+ // Value decomposition: v = v0 + v1*8 + v2*16 (v0=0..7, v1=0..1, v2=0..1)
175176 uint32_t a_regs[4 ];
176177 for (int reg = 0 ; reg < 4 ; reg++) {
177178 uint32_t packed = 0 ;
178179 for (int nib = 0 ; nib < 8 ; nib++) {
179- // v = reg * 8 + nib (value index 0..31)
180- int v0 = nib; // 0..7
181- int v1 = (reg / 1 ) % 2 ; // reg 0,1 -> v1=0; wait need to recompute
182- int v2 = reg / 2 ; // reg 0,1 -> v2=0; reg 2,3 -> v2=1
183-
184- // Actually reg maps to: reg0 = v[0..7], reg1 = v[8..15], etc.
185- // v = reg*8 + nib
186180 int v = reg * 8 + nib;
187- v0 = v % 8 ;
188- v1 = (v / 8 ) % 2 ;
189- v2 = v / 16 ;
181+ int v0 = v % 8 ;
182+ int v1 = (v / 8 ) % 2 ;
183+ int v2 = v / 16 ;
190184
191- int element_idx = t0 * 128 + t1 * 1 + v0 * 16 + v1 * 8 + v2 * 512 ;
192- int row = element_idx / 64 ; // M index within tile
193- int col = element_idx % 64 ; // K index within tile
185+ int coord = t0 * 128 + t1 + v0 * 16 + v1 * 8 + v2 * 512 ;
186+ int tile_row = coord % 16 ; // M index within tile (column-major)
187+ int tile_col = coord / 16 ; // K index within tile
194188
195- int global_m = tile_m + row ;
196- int global_k = k_start + col ;
189+ int global_m = tile_m + tile_row ;
190+ int global_k = k_start + tile_col ;
197191
198192 uint32_t nibble = 0 ;
199193 if (global_m < M && global_k < K) {
@@ -210,7 +204,8 @@ __global__ void kGemmNVFP4_simple(
210204 }
211205
212206 // Load B registers: 2 x uint32 (16 E2M1 values per thread)
213- // BLayout: element_idx = t0*64 + t1 + v0*8 + v1*256
207+ // BLayout: coord = t0*64 + t1 + v0*8 + v1*256
208+ // CuTe coord space is column-major within the tile: n = coord%8, k = coord/8
214209 uint32_t b_regs[2 ];
215210 for (int reg = 0 ; reg < 2 ; reg++) {
216211 uint32_t packed = 0 ;
@@ -219,12 +214,12 @@ __global__ void kGemmNVFP4_simple(
219214 int v0 = v % 8 ;
220215 int v1 = v / 8 ;
221216
222- int element_idx = t0 * 64 + t1 * 1 + v0 * 8 + v1 * 256 ;
223- int row = element_idx / 64 ; // N index within tile
224- int col = element_idx % 64 ; // K index within tile
217+ int coord = t0 * 64 + t1 + v0 * 8 + v1 * 256 ;
218+ int tile_row = coord % 8 ; // N index within tile (column-major)
219+ int tile_col = coord / 8 ; // K index within tile
225220
226- int global_n = tile_n + row ;
227- int global_k = k_start + col ;
221+ int global_n = tile_n + tile_row ;
222+ int global_k = k_start + tile_col ;
228223
229224 uint32_t nibble = 0 ;
230225 if (global_n < N && global_k < K) {
0 commit comments