Skip to content

Commit 0de3a74

Browse files
TimDettmersclaude
andcommitted
perf: Increase GEMM block size to 8 warps (m32×n128)
NCU profiling showed 5.42 active warps/cycle (low occupancy). Increase from 4 to 8 warps per block with 2D warp mapping (2 M-warps × 4 N-warps). Block tile now m32×n128. Added __launch_bounds__ for register pressure control. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3230e4c commit 0de3a74

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,32 +88,41 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
8888

8989
// N-tiles per warp: each warp computes m16 x (N_TILE_PER_WARP * 8)
9090
#define N_TILES_PER_WARP 4
91-
#define WARPS_PER_BLOCK 4
91+
// Block config: M_WARPS x N_WARPS warps per block
92+
// M_WARPS groups along M (each m16), N_WARPS groups along N (each handles N_TILES_PER_WARP n8-tiles)
93+
#define M_WARPS 2
94+
#define N_WARPS 4
95+
#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
9296

93-
__global__ void kGemmNVFP4_opt(
97+
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 2) void kGemmNVFP4_opt(
9498
const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major)
9599
const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
96100
const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales
97101
const unsigned char* __restrict__ SFB, // N x K/16 UE4M3 scales
98102
float* __restrict__ D, // M x N output (F32)
99103
int M, int N, int K
100104
) {
101-
// Block tile: m16 x n(WARPS_PER_BLOCK * N_TILES_PER_WARP * 8)
102-
// = m16 x n128 for 4 warps with 4 n-tiles each
103-
const int BLOCK_N = WARPS_PER_BLOCK * N_TILES_PER_WARP * 8;
105+
// Block tile: m(M_WARPS*16) x n(N_WARPS * N_TILES_PER_WARP * 8)
106+
// = m32 x n128 for 2x4 warps
107+
const int BLOCK_M = M_WARPS * 16;
108+
const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8;
104109

105110
int warp_in_block = threadIdx.x / 32;
106111
int lane_id = threadIdx.x % 32;
107112

113+
// 2D warp mapping: m_warp along M, n_warp along N
114+
int m_warp = warp_in_block / N_WARPS; // 0..(M_WARPS-1)
115+
int n_warp = warp_in_block % N_WARPS; // 0..(N_WARPS-1)
116+
108117
// Block-level tile position
109-
int tile_m = blockIdx.y * 16;
118+
int tile_m = blockIdx.y * BLOCK_M + m_warp * 16;
110119
int tile_n_base = blockIdx.x * BLOCK_N;
111120

112121
if (tile_m >= M)
113122
return;
114123

115124
// This warp's N offset within the block
116-
int warp_n_base = tile_n_base + warp_in_block * N_TILES_PER_WARP * 8;
125+
int warp_n_base = tile_n_base + n_warp * N_TILES_PER_WARP * 8;
117126

118127
// CuTE thread decomposition: t0 = lane%4 (0-3), t1 = lane/4 (0-7)
119128
int t0 = lane_id % 4;
@@ -431,14 +440,14 @@ extern "C" void cgemm_nvfp4(
431440
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
432441
int N, int K
433442
) {
434-
// Block tile: m16 x n(WARPS_PER_BLOCK * N_TILES_PER_WARP * 8)
435-
const int BLOCK_N = WARPS_PER_BLOCK * N_TILES_PER_WARP * 8; // 128
443+
const int BLOCK_M = M_WARPS * 16;
444+
const int BLOCK_N = N_WARPS * N_TILES_PER_WARP * 8;
436445

437-
int num_m_tiles = (M + 15) / 16;
446+
int num_m_blocks = (M + BLOCK_M - 1) / BLOCK_M;
438447
int num_n_blocks = (N + BLOCK_N - 1) / BLOCK_N;
439448

440-
dim3 grid(num_n_blocks, num_m_tiles);
441-
int threads_per_block = WARPS_PER_BLOCK * 32; // 128
449+
dim3 grid(num_n_blocks, num_m_blocks);
450+
int threads_per_block = WARPS_PER_BLOCK * 32; // 256
442451

443452
kGemmNVFP4_opt<<<grid, threads_per_block>>>(A, B, SFA, SFB, D, M, N, K);
444453
}

0 commit comments

Comments
 (0)