@@ -88,32 +88,41 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
8888
8989// N-tiles per warp: each warp computes m16 x (N_TILES_PER_WARP * 8)
9090#define N_TILES_PER_WARP 4
91- #define WARPS_PER_BLOCK 4
91+ // Block config: M_WARPS x N_WARPS warps per block
92+ // M_WARPS groups along M (each m16), N_WARPS groups along N (each handles N_TILES_PER_WARP n8-tiles)
93+ #define M_WARPS 2
94+ #define N_WARPS 4
95+ #define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
9296
93- __global__ void kGemmNVFP4_opt (
97+ __global__ __launch_bounds__ (WARPS_PER_BLOCK * 32 , 2 ) void kGemmNVFP4_opt(
9498 const unsigned char * __restrict__ A, // M x K/2 packed FP4 (row-major)
9599 const unsigned char * __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
96100 const unsigned char * __restrict__ SFA, // M x K/16 UE4M3 scales
97101 const unsigned char * __restrict__ SFB, // N x K/16 UE4M3 scales
98102 float * __restrict__ D, // M x N output (F32)
99103 int M, int N, int K
100104) {
101- // Block tile: m16 x n(WARPS_PER_BLOCK * N_TILES_PER_WARP * 8)
102- // = m16 x n128 for 4 warps with 4 n-tiles each
103- const int BLOCK_N = WARPS_PER_BLOCK * N_TILES_PER_WARP * 8 ;
105+ // Block tile: m(M_WARPS*16) x n(N_WARPS * N_TILES_PER_WARP * 8)
106+ // = m32 x n128 for 2x4 warps
107+ const int BLOCK_M = M_WARPS * 16 ;
108+ const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8 ;
104109
105110 int warp_in_block = threadIdx .x / 32 ;
106111 int lane_id = threadIdx .x % 32 ;
107112
113+ // 2D warp mapping: m_warp along M, n_warp along N
114+ int m_warp = warp_in_block / N_WARPS; // 0..(M_WARPS-1)
115+ int n_warp = warp_in_block % N_WARPS; // 0..(N_WARPS-1)
116+
108117 // Block-level tile position
109- int tile_m = blockIdx .y * 16 ;
118+ int tile_m = blockIdx .y * BLOCK_M + m_warp * 16 ;
110119 int tile_n_base = blockIdx .x * BLOCK_N;
111120
112121 if (tile_m >= M)
113122 return ;
114123
115124 // This warp's N offset within the block
116- int warp_n_base = tile_n_base + warp_in_block * N_TILES_PER_WARP * 8 ;
125+ int warp_n_base = tile_n_base + n_warp * N_TILES_PER_WARP * 8 ;
117126
118127 // CuTE thread decomposition: t0 = lane%4 (0-3), t1 = lane/4 (0-7)
119128 int t0 = lane_id % 4 ;
@@ -431,14 +440,14 @@ extern "C" void cgemm_nvfp4(
431440 const unsigned char * A, const unsigned char * B, const unsigned char * SFA, const unsigned char * SFB, float * D, int M,
432441 int N, int K
433442) {
434- // Block tile: m16 x n(WARPS_PER_BLOCK * N_TILES_PER_WARP * 8)
435- const int BLOCK_N = WARPS_PER_BLOCK * N_TILES_PER_WARP * 8 ; // 128
443+ const int BLOCK_M = M_WARPS * 16 ;
444+ const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8 ;
436445
437- int num_m_tiles = (M + 15 ) / 16 ;
446+ int num_m_blocks = (M + BLOCK_M - 1 ) / BLOCK_M ;
438447 int num_n_blocks = (N + BLOCK_N - 1 ) / BLOCK_N;
439448
440- dim3 grid (num_n_blocks, num_m_tiles );
441- int threads_per_block = WARPS_PER_BLOCK * 32 ; // 128
449+ dim3 grid (num_n_blocks, num_m_blocks );
450+ int threads_per_block = WARPS_PER_BLOCK * 32 ; // 256
442451
443452 kGemmNVFP4_opt <<<grid, threads_per_block>>> (A, B, SFA, SFB, D, M, N, K);
444453}
0 commit comments