Skip to content

Commit 543f7f8

Browse files
TimDettmersclaude
andcommitted
perf: Reduce register pressure (N_TILES=2, 16 warps/block)
NCU showed 80 regs/thread from N_TILES_PER_WARP=4 accumulators. Reduce to 2 N-tiles per warp (8 acc floats instead of 16), increase to 4×4=16 warps per block (m64×n64 tile, 512 threads). Target 2 blocks/SM for better occupancy. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0de3a74 commit 543f7f8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

csrc/kernels_nvfp4_sm120.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,15 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
8686
// Shared memory: A tile (512 bytes) + SFA (64 bytes) per K-step
8787
// ============================================================================
8888

89-
// N-tiles per warp: each warp computes m16 x (N_TILE_PER_WARP * 8)
90-
#define N_TILES_PER_WARP 4
89+
// N-tiles per warp: each warp computes m16 x (N_TILES_PER_WARP * 8)
90+
#define N_TILES_PER_WARP 2
9191
// Block config: M_WARPS x N_WARPS warps per block
9292
// M_WARPS groups along M (each m16), N_WARPS groups along N (each handles N_TILES_PER_WARP n8-tiles)
93-
#define M_WARPS 2
93+
#define M_WARPS 4
9494
#define N_WARPS 4
95-
#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
95+
#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 16
9696

97+
// 512 threads, target 2 blocks/SM for good occupancy
9798
__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 2) void kGemmNVFP4_opt(
9899
const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major)
99100
const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)

0 commit comments

Comments
 (0)