Commit 4fc1d8e

TimDettmers and claude committed
Increase k_splits occupancy targets on datacenter GPUs

- MMA launcher: increase TARGET_BLOCKS_PER_SM from 4→6 (TN=64) and 1→2 (TN=128) when num_sms > 130 (H100 has 132 SMs).
- Grouped MMA launcher: same occupancy target adjustment.
- Detection is runtime (num_sms > 130), not compile-time, so consumer GPUs with ≤130 SMs keep the existing targets unchanged.

H100 SXM benchmark improvement (CUDA graph, 500 iters × 5 trials):
- gateup k=4 M=1: 30.6→28.8 µs (-5.9%)
- gateup k=4 M=16: 34.9→29.5 µs (-15.5%)
- Q k=2 M=1: 23.0→21.4 µs (-7.0%)
- O k=4 M=1: 26.4→24.1 µs (-8.7%)
- KV: neutral (small shape, already fully occupied)

Consumer regression: zero (174/174 tests pass, RTX 4090 has 128 SMs so the higher targets are never activated).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 2cc0224 commit 4fc1d8e

1 file changed, +15 -7 lines changed


csrc/ops.cu

Lines changed: 15 additions & 7 deletions
@@ -1449,11 +1449,15 @@ static void kbitGemmProdLaunch(
     int mn_tiles = m_tiles * n_tiles;
 
     // k_splits heuristic: target enough blocks for good SM occupancy.
-    // With BLOCK_DIM threads/block, we want ~4 blocks/SM for latency hiding.
-    // BLOCK_DIM=128 (TN=64): 4 blocks/SM → 16 warps → 33% occupancy
-    // BLOCK_DIM=256 (TN=128): 1 block/SM → 8 warps → 16% occupancy (ok for large M)
-    constexpr int TARGET_BLOCKS_PER_SM = (BLOCK_DIM <= 128) ? 4 : 1;
-    int target_blocks = num_sms * TARGET_BLOCKS_PER_SM;
+    // Datacenter GPUs (H100) have higher bandwidth and can sustain more concurrent blocks.
+    // TN=64: 4 blocks/SM (consumer), 6 blocks/SM (datacenter) for better latency hiding
+    // TN=128: 1 block/SM (consumer), 2 blocks/SM (datacenter) to exploit larger shmem
+    int target_blocks_per_sm;
+    if constexpr (BLOCK_DIM <= 128)
+        target_blocks_per_sm = (num_sms > 130) ? 6 : 4; // H100: 132 SMs
+    else
+        target_blocks_per_sm = (num_sms > 130) ? 2 : 1;
+    int target_blocks = num_sms * target_blocks_per_sm;
 
     int k_splits = 1;
     if (mn_tiles < target_blocks && k_tiles > 1) {
@@ -1924,8 +1928,12 @@ static void kbitGroupedGemmProdLaunch(
     int mn_tiles = num_experts * m_tiles_per_expert * n_tiles;
 
     // k_splits heuristic: target enough blocks for good SM occupancy
-    constexpr int TARGET_BLOCKS_PER_SM = (BLOCK_DIM <= 128) ? 4 : 1;
-    int target_blocks = num_sms * TARGET_BLOCKS_PER_SM;
+    int target_blocks_per_sm;
+    if constexpr (BLOCK_DIM <= 128)
+        target_blocks_per_sm = (num_sms > 130) ? 6 : 4;
+    else
+        target_blocks_per_sm = (num_sms > 130) ? 2 : 1;
+    int target_blocks = num_sms * target_blocks_per_sm;
 
     int k_splits = 1;
     if (mn_tiles < target_blocks && k_tiles > 1) {
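
To make the effect of the new targets concrete, here is a minimal standalone sketch of the selection logic from both hunks. The helper names (`targetBlocksPerSm`, `targetBlocks`, `targetWarpsPerSm`) are hypothetical and do not appear in csrc/ops.cu. Only the 130-SM threshold, the 4/6 and 1/2 targets, and the 128- and 132-SM counts are taken from the commit. The 32-thread warp size is the standard CUDA value. How `k_splits` is then derived from `target_blocks` is not shown in the diff, so it is not reproduced here.

```cpp
// Hypothetical helpers mirroring the heuristic in the diff (C++17).
// They are constexpr so the worked values can be checked at compile time.

template <int BLOCK_DIM>
constexpr int targetBlocksPerSm(int num_sms) {
    // Runtime-style check from the commit: more than 130 SMs is treated
    // as a datacenter GPU (H100 SXM has 132 SMs, RTX 4090 has 128).
    const bool datacenter = num_sms > 130;
    if constexpr (BLOCK_DIM <= 128)
        return datacenter ? 6 : 4;  // TN=64 path
    else
        return datacenter ? 2 : 1;  // TN=128 path
}

// Total number of blocks the launcher aims for across the whole GPU.
template <int BLOCK_DIM>
constexpr int targetBlocks(int num_sms) {
    return num_sms * targetBlocksPerSm<BLOCK_DIM>(num_sms);
}

// Warps per SM implied by the target, assuming 32 threads per warp.
template <int BLOCK_DIM>
constexpr int targetWarpsPerSm(int num_sms) {
    return targetBlocksPerSm<BLOCK_DIM>(num_sms) * (BLOCK_DIM / 32);
}

// Worked values for the two GPUs named in the commit message.
static_assert(targetBlocks<128>(128) == 512, "RTX 4090, TN=64: 128 * 4");
static_assert(targetBlocks<128>(132) == 792, "H100, TN=64: 132 * 6");
static_assert(targetBlocks<256>(128) == 128, "RTX 4090, TN=128: 128 * 1");
static_assert(targetBlocks<256>(132) == 264, "H100, TN=128: 132 * 2");
```

Under these assumptions the TN=64 target on a 132-SM part rises from 528 blocks (132 × 4) to 792, and the implied warps per SM rise from 16 to 24. For TN=128 the target goes from 8 to 16 warps per SM. Whether that many blocks are actually resident at once depends on per-block register and shared-memory use, which this sketch does not model.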
