Small Tile M BlockScaled GEMM + Grouped GEMM on SM12x #3196
besquared wants to merge 4 commits into NVIDIA:main
Conversation
Mirror the small-N scale-tensor padding pattern from PR NVIDIA#3176 for SFA so M=64 cooperative blockscaled kernels use a padded TMA box, broadcasted producer coordinates, and sliced consumer layouts. Add dense, active GemmUniversalAdapter, and grouped coverage for SM120 NVFP4 small-tile shapes.
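A rough host-side sketch of that broadcast-then-slice idea follows; it uses plain arrays and illustrative sizes (`kTileM`, `kPaddedM` are assumptions, not the PR's values) rather than the PR's cute layouts or TMA descriptors.

```cpp
#include <array>
#include <cstdio>

// Conceptual sketch only: plain arrays stand in for the SFA TMA box and smem
// staging. kTileM / kPaddedM are illustrative assumptions, not the PR's code.
constexpr int kTileM   = 64;    // logical SFA rows for the M=64 tile
constexpr int kPaddedM = 128;   // padded box extent the copy is issued against

int main() {
  // Global-memory scale factors for the logical tile.
  std::array<float, kTileM> sfa_gmem{};
  for (int m = 0; m < kTileM; ++m) sfa_gmem[m] = 1.0f + static_cast<float>(m);

  // Producer side: fill the padded box, broadcasting valid row coordinates into
  // the padding so every element of the box maps to in-bounds data.
  std::array<float, kPaddedM> sfa_smem{};
  for (int m = 0; m < kPaddedM; ++m) sfa_smem[m] = sfa_gmem[m % kTileM];

  // Consumer side: each warp slices only the rows it owns out of the padded box.
  for (int m = 16; m < 32; ++m) {
    std::printf("row %2d -> scale %.0f\n", m, sfa_smem[m]);
  }
  return 0;
}
```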
```diff
   //
-  auto [gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl] = load_inputs;
+  auto gA_mkl = get<0>(load_inputs);
```
Is there a reason that this needed to be changed to get?
Fixed, it was just a stylistic regression. I had gpt5.5 sweep for other stylistic and idiomatic regressions and things looked ok.
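For reference, the two forms are equivalent; here is a minimal `std::tuple` stand-in (not the kernel's cute tensors) showing the style difference that was reverted.

```cpp
#include <tuple>

int main() {
  auto load_inputs = std::make_tuple(1, 2, 3, 4);

  // Idiomatic form kept after the revert: unpack everything at once.
  auto [gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl] = load_inputs;

  // Element-wise form the review flagged: same value, different style.
  auto gA_mkl_alt = std::get<0>(load_inputs);

  return (gA_mkl == gA_mkl_alt) ? 0 : 1;  // exits 0: the forms agree
}
```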
```cpp
//
// Compute on k_tile
//
if constexpr (SingleCtaKBlock) {
```
For clarification, is this just intended as a constexpr branch for `TileShapeK == MMA_K`?
I updated this so that it conditions on the implication of the tile shape (K_BLOCK_MAX=1) rather than on the tile shape itself, which hopefully clarifies the semantics.
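A compilable toy of what that condition expresses, with hypothetical `TileShapeK`/`MmaK` parameters standing in for the mainloop's real types:

```cpp
#include <cstdio>

// Toy model: K_BLOCK_MAX is the number of MMA k-blocks a CTA k-tile decomposes
// into. Branching on K_BLOCK_MAX == 1 states the intent (the whole tile is one
// k-block) instead of hard-coding a particular TileShapeK value.
template <int TileShapeK, int MmaK>
inline constexpr int K_BLOCK_MAX = TileShapeK / MmaK;

template <int TileShapeK, int MmaK>
void compute_on_k_tile() {
  if constexpr (K_BLOCK_MAX<TileShapeK, MmaK> == 1) {
    std::printf("TileShapeK=%d: single k-block, no inner k-block loop\n", TileShapeK);
  }
  else {
    std::printf("TileShapeK=%d: %d k-blocks per tile\n",
                TileShapeK, K_BLOCK_MAX<TileShapeK, MmaK>);
  }
}

int main() {
  compute_on_k_tile<64, 64>();   // small-K tile from this PR: one k-block
  compute_on_k_tile<128, 64>();  // larger tile: two k-blocks
  return 0;
}
```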
Would it also be possible to add the tile sizes to the generator?
I mirrored the #3176 generator pattern and added the new small-M shapes to the SM120 blockscaled generator lists. I also added the analogous grouped pingpong smoke test for M=64: `SM120_Device_Gemm_e2m1t_e2m1n_e2m1t_tensorop_f32_epilogue_VS16_group_pingpong.row_sf_64x128x128`. That gives representative cooperative small-tile coverage plus a grouped pingpong smoke test for the generated pingpong schedule, exactly as in #3176.
Adds BlockScaled GEMM and Grouped GEMM support for small tile M = 64 and K = 64 on RTX / DGX Spark GPUs targeting SM120 and SM121. Completes the symmetric counterpart to #3176 (small N).
Motivation: fused attention. NVFP4 attention kernels at head_dim=256 benefit from M=64 because it enables 4-thread-per-row softmax distribution with 8 mma warps (vs 2 threads/row at M=128), removing warp imbalance during online softmax production. K=64 reduces the QK gemm's per-stage scale tensor footprint, freeing smem budget for the attention scaffold's score and output staging.
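The thread-per-row arithmetic behind that claim, as a quick sanity check (assuming all 8 mma warps participate in the row-wise softmax):

```cpp
#include <cstdio>

// Back-of-envelope check: 8 warps x 32 threads = 256 threads over the score
// tile's M rows. Warp and thread counts are the PR description's numbers.
int main() {
  constexpr int warps = 8;
  constexpr int threads = warps * 32;                           // 256
  std::printf("M=64:  %d threads per row\n", threads / 64);     // 4
  std::printf("M=128: %d threads per row\n", threads / 128);    // 2
  return 0;
}
```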
Changes:
- Mirror the small-N scale-tensor padding pattern from #3176 for SFA, so M=64 cooperative blockscaled kernels use a padded SFA TMA box with broadcast producer coordinates, paired with consumer-side slice partitioning.
- Add dense, GemmUniversalAdapter, and grouped test coverage for the SM120 NVFP4 small-tile shapes.
Verified on RTX 6000 Pro Blackwell (sm_120).
Note on scope: M = 32 is not included. The mma.sync NVFP4 atom is m16n8k64; the m = 16 dimension forces M = 64 as the minimum tile that keeps all 4 cooperative mma warps engaged with at least one atom each. M = 32 would leave 2 of 4 warps idle and is structurally degenerate. #3176's small-N support could go to N = 32 because n = 8 in the same atom permits full warp utilization at the smaller dim. The asymmetry between m and n in this PR's small-tile coverage reflects the asymmetry in the underlying mma atom, not a scope cut.
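The atom-count arithmetic behind that floor, assuming the four cooperative mma warps tile along the dimension being shrunk:

```cpp
#include <cstdio>

// m16n8k64 atom: shrinking M below 64 leaves fewer m-atoms than cooperative
// warps, while N=32 still yields one n-atom per warp. Warp count is the PR's.
int main() {
  constexpr int atom_m = 16, atom_n = 8, coop_warps = 4;
  std::printf("M=64 -> %d m-atoms (>= %d warps: all busy)\n", 64 / atom_m, coop_warps);
  std::printf("M=32 -> %d m-atoms (2 of %d warps idle)\n",    32 / atom_m, coop_warps);
  std::printf("N=32 -> %d n-atoms (small-N stays viable)\n",  32 / atom_n);
  return 0;
}
```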