Commit 11f4b32
Optimize MMA kernel for small M: TILE_N=64 + multi-block-per-SM k_splits
For M<=16, the MMA kernel now uses TILE_N=64 (4 warps, 128 threads) instead
of TILE_N=128 (8 warps, 256 threads). This doubles n_tiles for better SM
coverage. Combined with an aggressive k_splits heuristic targeting 4 blocks
per SM, occupancy jumps from 8% to ~28%.
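The tile-count and occupancy figures above can be sanity-checked with a little arithmetic. This is a rough sketch, not code from the repository. It assumes 32 threads per warp and a limit of 48 resident warps per SM (1536 threads, as on compute capability 8.9). The helper names `n_tiles` and `theoretical_occupancy` are hypothetical.

```python
WARP_SIZE = 32
MAX_WARPS_PER_SM = 48  # assumption: 1536 resident threads per SM (compute capability 8.9)


def n_tiles(n: int, tile_n: int) -> int:
    """Number of column tiles needed to cover N output columns."""
    return (n + tile_n - 1) // tile_n


def theoretical_occupancy(blocks_per_sm: int, threads_per_block: int) -> float:
    """Resident warps divided by the per-SM warp limit (an upper bound, not a measurement)."""
    warps = blocks_per_sm * (threads_per_block // WARP_SIZE)
    return min(warps, MAX_WARPS_PER_SM) / MAX_WARPS_PER_SM


# Halving TILE_N doubles the number of column tiles for N=2048.
print(n_tiles(2048, 128), n_tiles(2048, 64))

# One 128-thread block per SM versus four 128-thread blocks per SM.
print(f"{theoretical_occupancy(1, 128):.1%}", f"{theoretical_occupancy(4, 128):.1%}")
```

Under these assumptions, one 128-thread block per SM gives about 8% and four blocks give a ceiling of about 33%. That is consistent with the reported 8% and ~28%, though the message does not say how those figures were measured.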
Key changes:
- Template kbit_gemm_prod on TILE_N_VAL (default 128, use 64 for M<=16)
- Derive NUM_WARPS and COLS_PER_WARP from TILE_N instead of hardcoding
- k_splits heuristic targets 4 blocks/SM for TILE_N=64 (128-thread blocks)
- Python workspace allocation uses TILE_N=64 worst case for tile_counters
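The changes listed above can be modeled on the host side as a minimal sketch. It is an illustration under stated assumptions, not the repository's implementation. The 16 columns per warp are inferred from the message (128/8 = 64/4). `TILE_M = 16`, `TILE_K = 64`, and every function and class name below are hypothetical. The real k_splits heuristic may differ.

```python
from dataclasses import dataclass

WARP_SIZE = 32
BASE_COLS_PER_WARP = 16  # inferred from the message: 128 cols / 8 warps == 64 cols / 4 warps
TILE_M = 16              # assumption: one M tile covers M <= 16
TILE_K = 64              # assumption: hypothetical K chunk size


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


@dataclass(frozen=True)
class TileConfig:
    """Warp layout derived from TILE_N instead of hardcoded constants."""
    tile_n: int

    @property
    def num_warps(self) -> int:
        return self.tile_n // BASE_COLS_PER_WARP

    @property
    def cols_per_warp(self) -> int:
        return self.tile_n // self.num_warps

    @property
    def threads(self) -> int:
        return self.num_warps * WARP_SIZE


def pick_tile_n(m: int) -> int:
    """Small-M problems use the narrower tile, as the commit describes."""
    return 64 if m <= 16 else 128


def pick_k_splits(base_blocks: int, k_chunks: int, num_sms: int,
                  blocks_per_sm: int = 4) -> int:
    """Hypothetical heuristic: split K until the grid reaches the block target."""
    target = num_sms * blocks_per_sm
    if base_blocks >= target:
        return 1
    return max(1, min(k_chunks, ceil_div(target, base_blocks)))


def tile_counter_count(m: int, n: int, smallest_tile_n: int = 64) -> int:
    """Size the counters for the smallest TILE_N, which yields the most tiles."""
    return ceil_div(m, TILE_M) * ceil_div(n, smallest_tile_n)


# Example: the benchmarked shape (M=4, K=5120, N=2048) on a 128-SM GPU.
cfg = TileConfig(pick_tile_n(4))
base = ceil_div(4, TILE_M) * ceil_div(2048, cfg.tile_n)
splits = pick_k_splits(base, ceil_div(5120, TILE_K), num_sms=128)
print(cfg.num_warps, cfg.threads, base, splits, tile_counter_count(4, 2048))
```

For the benchmarked shape, the sketch produces 4 warps, 128 threads, 32 base blocks, and 32 tile counters. The k_splits value it prints reflects only this sketch's assumed heuristic, not the value the kernel actually selects.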
Benchmark (ncu, RTX 4090, dense_down K=5120 N=2048):
k=2 M=4: 10.34 us (vs scalar GEMV 17.82 us = 1.72x faster)
k=4 M=3: 12.93 us (vs scalar GEMV 16.58 us = 1.28x faster)
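The speedup figures are the ratio of the scalar GEMV time to the MMA kernel time. A short check, with the hypothetical helper `speedup`, reproduces them from the timings quoted above:

```python
def speedup(baseline_us: float, new_us: float) -> float:
    """How many times faster the new kernel is than the baseline."""
    return baseline_us / new_us


# Timings in microseconds, taken from the benchmark lines above.
print(f"k=2 M=4: {speedup(17.82, 10.34):.2f}x")
print(f"k=4 M=3: {speedup(16.58, 12.93):.2f}x")
```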
Also attempted dequant-to-shmem (Phase 2) but reverted it: serializing
the full dequant pass before MMA eliminates pipeline interleaving,
resulting in a 2.6x regression. Inline dequant remains the faster approach.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 05638f5 commit 11f4b32
2 files changed: +147 -156 lines changed