Skip to content

Commit 53d9af8

Browse files
authored
Autotune between tma gather and cp.async in SM100 (Dao-AILab#100)
* autotune tma gather
* fix comment
* ruff format
1 parent d15e0b6 commit 53d9af8

4 files changed

Lines changed: 24 additions & 3 deletions

File tree

quack/gemm_act.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def _compile_gemm_act(
286286
gemm_cls_name,
287287
rounding_mode=RoundingMode.RN,
288288
sr_seed_mode=0,
289+
use_tma_gather=False,
289290
):
290291
sm_to_cls = {
291292
"act": {9: GemmActSm90, 10: GemmActSm100, 11: GemmActSm100, 12: GemmActSm120},
@@ -360,6 +361,7 @@ def fake_scalar(mode, dtype=Int32):
360361
epi_args,
361362
scheduler_args,
362363
varlen_args,
364+
use_tma_gather=use_tma_gather,
363365
)
364366

365367

@@ -385,6 +387,7 @@ def gemm_act(
385387
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
386388
rounding_mode: int = RoundingMode.RN,
387389
sr_seed: int | Tensor = 0,
390+
use_tma_gather: bool = False,
388391
) -> None:
389392
if activation in gate_fn_map:
390393
gemm_cls_name = "gated"
@@ -462,6 +465,7 @@ def gemm_act(
462465
gemm_cls_name,
463466
rounding_mode=rounding_mode,
464467
sr_seed_mode=sr_seed_mode,
468+
use_tma_gather=use_tma_gather,
465469
)
466470

467471
from quack.cache_utils import COMPILE_ONLY

quack/gemm_config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class GemmConfig:
1818
# raster_order: int = 1
1919
max_swizzle_size: int = 8
2020
device_capacity: int = 9
21+
# whether to use TMA gather (vs normal cp.async) for gather_A on SM100
22+
use_tma_gather: bool = False
2123

2224

2325
def _get_sm90_configs(
@@ -58,6 +60,7 @@ def _get_sm90_configs(
5860
swap_ab=swap_ab,
5961
device_capacity=9,
6062
is_dynamic_persistent=False, # default to not use dynamic persistent on SM90
63+
use_tma_gather=False, # TMA gather not supported on SM90
6164
)
6265
for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
6366
tile_mn_vals,
@@ -87,6 +90,7 @@ def _get_sm100_configs(
8790
GemmConfig, pingpong=False, device_capacity=10
8891
) # There's no pingpong on Sm100
8992
use_clc_vals = [True, False]
93+
use_tma_gather_vals = [True, False]
9094
return [
9195
GemmConfigCls(
9296
tile_m=m,
@@ -96,9 +100,10 @@ def _get_sm100_configs(
96100
swap_ab=sab,
97101
max_swizzle_size=8,
98102
is_dynamic_persistent=use_clc,
103+
use_tma_gather=use_tma_gather,
99104
)
100-
for (m, n, (cm, cn)), sab, use_clc in itertools.product(
101-
tile_mn_cluster_vals, swap_ab_vals, use_clc_vals
105+
for (m, n, (cm, cn)), sab, use_clc, use_tma_gather in itertools.product(
106+
tile_mn_cluster_vals, swap_ab_vals, use_clc_vals, use_tma_gather_vals
102107
)
103108
]
104109

@@ -126,6 +131,7 @@ def _get_sm120_configs(
126131
swap_ab=swap_ab,
127132
device_capacity=12,
128133
is_dynamic_persistent=True,
134+
use_tma_gather=False, # TMA gather not supported on SM120
129135
)
130136
for (tile_m, tile_n, pingpong), swap_ab in itertools.product(tile_mn_vals, swap_ab_vals)
131137
]

quack/gemm_dact.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def _compile_gemm_dact(
259259
gather_A,
260260
device_capacity,
261261
gemm_cls_name,
262+
use_tma_gather=False,
262263
):
263264
is_dgated = gemm_cls_name == "dgated"
264265
sm_to_cls = {
@@ -352,6 +353,7 @@ def _set_implicit_dtype(gemm_obj):
352353
scheduler_args,
353354
varlen_args,
354355
post_init=post_init,
356+
use_tma_gather=use_tma_gather,
355357
)
356358

357359

@@ -376,7 +378,7 @@ def gemm_dact(
376378
colvec_reduce: Optional[Tensor] = None,
377379
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
378380
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
379-
use_clc_persistence: bool = False,
381+
use_tma_gather: bool = False,
380382
) -> None:
381383
is_dgated = activation in dgate_fn_map
382384
if not is_dgated:
@@ -463,6 +465,7 @@ def gemm_dact(
463465
gather_A,
464466
device_capacity,
465467
gemm_cls_name,
468+
use_tma_gather=use_tma_gather,
466469
)
467470

468471
from quack.cache_utils import COMPILE_ONLY

quack/gemm_interface.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
114114
if device_capacity == 9:
115115
configs = [conf for conf in configs if conf.kwargs["config"].tile_n != 208]
116116
configs = [conf for conf in configs if not conf.kwargs["config"].is_dynamic_persistent]
117+
# use_tma_gather only valid when gather_A is active on SM100/SM110
118+
if not gather_A or device_capacity not in [10, 11]:
119+
configs = [conf for conf in configs if not conf.kwargs["config"].use_tma_gather]
117120
return configs
118121

119122

@@ -215,6 +218,7 @@ def gemm_tuned(
215218
add_to_output=add_to_output,
216219
rounding_mode=rounding_mode,
217220
sr_seed=sr_seed,
221+
use_tma_gather=config.use_tma_gather,
218222
)
219223

220224

@@ -286,6 +290,7 @@ def gemm_act_tuned(
286290
colvec_bias=bias if config.swap_ab else None,
287291
cu_seqlens_m=cu_seqlens_m,
288292
A_idx=A_idx,
293+
use_tma_gather=config.use_tma_gather,
289294
)
290295

291296

@@ -351,6 +356,7 @@ def gemm_dact_tuned(
351356
max_swizzle_size=config.max_swizzle_size,
352357
cu_seqlens_m=cu_seqlens_m,
353358
A_idx=A_idx,
359+
use_tma_gather=config.use_tma_gather,
354360
)
355361

356362

@@ -1198,6 +1204,7 @@ def gemm_gated_tuned(
11981204
colvec_bias=bias if config.swap_ab else None,
11991205
cu_seqlens_m=cu_seqlens_m,
12001206
A_idx=A_idx,
1207+
use_tma_gather=config.use_tma_gather,
12011208
)
12021209

12031210

@@ -1292,6 +1299,7 @@ def gemm_dgated_tuned(
12921299
colvec_reduce=colvec_reduce_partial,
12931300
cu_seqlens_m=cu_seqlens_m,
12941301
A_idx=A_idx,
1302+
use_tma_gather=config.use_tma_gather,
12951303
)
12961304
if colvec_reduce:
12971305
colvec_reduce_final = colvec_reduce_partial.sum(dim=-1)

0 commit comments

Comments
 (0)