Commit 940ac12

TimDettmers and claude committed
perf: Restore N_TILES=4, use launch_bounds(256,4) for occupancy
Target 4 blocks/SM to force compiler to reduce register count.
N_TILES_PER_WARP=4 gives better compute/load ratio (4 MMA per A load).
Block tile m32×n128 with 8 warps (2×4).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 543f7f8 commit 940ac12
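
The tile geometry quoted in the commit message can be checked with a little compile-time arithmetic. The sketch below is not part of the commit: `TileConfig` and the helper functions are hypothetical names, and the m16/n8 MMA tile sizes are taken from the source comment "each warp computes m16 x (N_TILES_PER_WARP * 8)".

```cpp
// Illustrative sketch only; not code from csrc/kernels_nvfp4_sm120.cu.
// TileConfig and the helpers below are hypothetical names.
struct TileConfig {
    int n_tiles_per_warp;  // n8 tiles computed by each warp
    int m_warps;           // warps along M, one m16 tile each
    int n_warps;           // warps along N
};

constexpr int kMmaM = 16;      // MMA tile height, per the source comment
constexpr int kMmaN = 8;       // MMA tile width, per the source comment
constexpr int kWarpSize = 32;  // threads per warp

constexpr int warps_per_block(TileConfig c) { return c.m_warps * c.n_warps; }
constexpr int threads_per_block(TileConfig c) { return warps_per_block(c) * kWarpSize; }
constexpr int block_tile_m(TileConfig c) { return c.m_warps * kMmaM; }
constexpr int block_tile_n(TileConfig c) { return c.n_warps * c.n_tiles_per_warp * kMmaN; }

// Values before and after this commit, read off the diff.
constexpr TileConfig kBefore{2, 4, 4};  // N_TILES_PER_WARP=2, M_WARPS=4, N_WARPS=4
constexpr TileConfig kAfter{4, 2, 4};   // N_TILES_PER_WARP=4, M_WARPS=2, N_WARPS=4

static_assert(threads_per_block(kBefore) == 512, "16 warps before");
static_assert(threads_per_block(kAfter) == 256, "8 warps after");
static_assert(block_tile_m(kAfter) == 32 && block_tile_n(kAfter) == 128,
              "block tile m32 x n128, as stated in the commit message");
```

With these numbers the block tile goes from m64×n64 to m32×n128. Both cover 4096 output elements, but each warp now reuses one A fragment across four n8 tiles instead of two.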

1 file changed

Lines changed: 5 additions & 6 deletions

csrc/kernels_nvfp4_sm120.cu
@@ -87,15 +87,14 @@ __device__ __forceinline__ uint32_t pack_8_nibbles_slow(const unsigned char* dat
 // ============================================================================
 
 // N-tiles per warp: each warp computes m16 x (N_TILES_PER_WARP * 8)
-#define N_TILES_PER_WARP 2
+#define N_TILES_PER_WARP 4
 // Block config: M_WARPS x N_WARPS warps per block
-// M_WARPS groups along M (each m16), N_WARPS groups along N (each handles N_TILES_PER_WARP n8-tiles)
-#define M_WARPS 4
+#define M_WARPS 2
 #define N_WARPS 4
-#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 16
+#define WARPS_PER_BLOCK (M_WARPS * N_WARPS) // 8
 
-// 512 threads, target 2 blocks/SM for good occupancy
-__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 2) void kGemmNVFP4_opt(
+// 256 threads, target 4 blocks/SM (limit regs to 48 via maxrregcount)
+__global__ __launch_bounds__(WARPS_PER_BLOCK * 32, 4) void kGemmNVFP4_opt(
 const unsigned char* __restrict__ A, // M x K/2 packed FP4 (row-major)
 const unsigned char* __restrict__ B, // N x K/2 packed FP4 (B transposed, row-major)
 const unsigned char* __restrict__ SFA, // M x K/16 UE4M3 scales
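
The occupancy reasoning behind `__launch_bounds__(WARPS_PER_BLOCK * 32, 4)` can be sketched as a register-budget calculation. This is a rough model, not code from the commit. It assumes 64K 32-bit registers per SM, ignores register allocation granularity as well as the shared-memory and resident-thread limits, and uses hypothetical helper names. Under those assumptions the old bounds (512, 2) and the new bounds (256, 4) both imply 1024 resident threads and therefore the same ceiling of 64 registers per thread. A 48-register limit would have to be set separately via maxrregcount, as the new comment notes.

```cpp
// Illustrative sketch only; not code from csrc/kernels_nvfp4_sm120.cu.
// Assumption: 64K 32-bit registers per SM. Allocation granularity and the
// shared-memory / thread-count limits are deliberately ignored.
constexpr int kRegsPerSM = 64 * 1024;

// Largest per-thread register count that still lets `min_blocks_per_sm`
// blocks of `threads_per_block` threads fit in the register file.
constexpr int reg_budget_per_thread(int threads_per_block, int min_blocks_per_sm) {
    return kRegsPerSM / (threads_per_block * min_blocks_per_sm);
}

// Number of blocks the register file alone could hold at a given
// per-thread register count (hypothetical helper).
constexpr int max_blocks_by_registers(int regs_per_thread, int threads_per_block) {
    return kRegsPerSM / (regs_per_thread * threads_per_block);
}

static_assert(reg_budget_per_thread(512, 2) == 64, "old bounds: 1024 resident threads");
static_assert(reg_budget_per_thread(256, 4) == 64, "new bounds: 1024 resident threads");
static_assert(max_blocks_by_registers(64, 256) == 4, "64 regs/thread fits 4 blocks");
static_assert(max_blocks_by_registers(48, 256) == 5, "48 regs/thread fits 5 blocks");
```

The actual register count the compiler settles on is best confirmed from the build output, for example with nvcc's `-Xptxas -v` option.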
