Skip to content

Commit b32988f

Browse files
committed
[Scheduler] Don't use cluster instruction unless necessary
1 parent c75fb75 commit b32988f

1 file changed

Lines changed: 33 additions & 7 deletions

File tree

quack/tile_scheduler.py

Lines changed: 33 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -204,8 +204,12 @@ def create(
204204
ip=None,
205205
) -> "TileScheduler":
206206
"""is_scheduler_warp should only be true for one warp in the whole cluster"""
207+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1):
208+
cluster_idx = cute.arch.block_idx()
209+
else:
210+
cluster_idx = cute.arch.cluster_idx()
207211
current_work_idx, _ = TileScheduler._cluster_idx_to_work_idx_batch(
208-
params, cute.arch.cluster_idx(), loc=loc, ip=ip
212+
params, cluster_idx, loc=loc, ip=ip
209213
)
210214
stages = 0
211215
if const_expr(
@@ -294,7 +298,9 @@ def _swizzle_cta(
294298
def _cluster_id_to_cta_id(
295299
self, cid_m: Int32, cid_n: Int32, *, block_zero_only: bool = False, loc=None, ip=None
296300
) -> Tuple[Int32, Int32]:
297-
if const_expr(block_zero_only):
301+
if const_expr(
302+
block_zero_only or cute.size(self.params.cluster_shape_mnk, loc=loc, ip=ip) == 1
303+
):
298304
bidx_in_cluster = (Int32(0), Int32(0))
299305
else:
300306
# Get the pid from cluster id
@@ -326,7 +332,11 @@ def _delinearize_work_idx(
326332
if is_valid:
327333
if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
328334
cluster_id_in_problem = work_idx
329-
_, _, bidz_ = cute.arch.cluster_idx()
335+
bidz_ = (
336+
cute.arch.block_idx()[2]
337+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1)
338+
else cute.arch.cluster_idx()[2]
339+
)
330340
else:
331341
bidz_, cluster_id_in_problem = divmod(work_idx, params.num_clusters_per_problem_fdd)
332342
if const_expr(bidz is not None):
@@ -380,7 +390,11 @@ def initial_work_tile_info(self, *, loc=None, ip=None) -> cutlass.utils.WorkTile
380390
def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32 | Tuple[Int32, Int32, Boolean]:
381391
"""should only be called by the scheduler warp"""
382392
params = self.params
383-
num_persistent_clusters = cute.arch.cluster_dim()[2]
393+
num_persistent_clusters = (
394+
cute.arch.grid_dim()[2]
395+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1)
396+
else cute.arch.cluster_dim()[2]
397+
)
384398
if const_expr(params.persistence_mode == PersistenceMode.STATIC):
385399
return self._current_work_idx + num_persistent_clusters
386400
# Serpentine: alternate wave direction for a bit better load balancing
@@ -641,8 +655,12 @@ def create(
641655
loc=None,
642656
ip=None,
643657
) -> "TriangularTileScheduler":
658+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1):
659+
cluster_idx = cute.arch.block_idx()
660+
else:
661+
cluster_idx = cute.arch.cluster_idx()
644662
current_work_idx, _ = TileScheduler._cluster_idx_to_work_idx_batch(
645-
params, cute.arch.cluster_idx(), loc=loc, ip=ip
663+
params, cluster_idx, loc=loc, ip=ip
646664
)
647665
stages = 0
648666
if const_expr(
@@ -762,7 +780,11 @@ def _delinearize_work_idx(
762780
if is_valid:
763781
if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
764782
cluster_id_in_problem = work_idx
765-
_, _, bidz_ = cute.arch.cluster_idx()
783+
bidz_ = (
784+
cute.arch.block_idx()[2]
785+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1)
786+
else cute.arch.cluster_idx()[2]
787+
)
766788
else:
767789
bidz_, cluster_id_in_problem = divmod(work_idx, params.num_clusters_per_problem_fdd)
768790
cluster_id_in_problem = Int32(cluster_id_in_problem) # divmod returns IntValue
@@ -917,8 +939,12 @@ def create(
917939
loc=None,
918940
ip=None,
919941
) -> "VarlenMTileScheduler":
942+
if const_expr(cute.size(params.cluster_shape_mnk, loc=loc, ip=ip) == 1):
943+
cluster_idx = cute.arch.block_idx()
944+
else:
945+
cluster_idx = cute.arch.cluster_idx()
920946
current_work_idx, _ = VarlenMTileScheduler._cluster_idx_to_work_idx_batch(
921-
params, cute.arch.cluster_idx(), loc=loc, ip=ip
947+
params, cluster_idx, loc=loc, ip=ip
922948
)
923949
stages = 0
924950
if const_expr(

0 commit comments

Comments
 (0)