@@ -204,8 +204,12 @@ def create(
204204 ip = None ,
205205 ) -> "TileScheduler" :
206206 """is_scheduler_warp should only be true for one warp in the whole cluster"""
207+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 ):
208+ cluster_idx = cute .arch .block_idx ()
209+ else :
210+ cluster_idx = cute .arch .cluster_idx ()
207211 current_work_idx , _ = TileScheduler ._cluster_idx_to_work_idx_batch (
208- params , cute . arch . cluster_idx () , loc = loc , ip = ip
212+ params , cluster_idx , loc = loc , ip = ip
209213 )
210214 stages = 0
211215 if const_expr (
@@ -294,7 +298,9 @@ def _swizzle_cta(
294298 def _cluster_id_to_cta_id (
295299 self , cid_m : Int32 , cid_n : Int32 , * , block_zero_only : bool = False , loc = None , ip = None
296300 ) -> Tuple [Int32 , Int32 ]:
297- if const_expr (block_zero_only ):
301+ if const_expr (
302+ block_zero_only or cute .size (self .params .cluster_shape_mnk , loc = loc , ip = ip ) == 1
303+ ):
298304 bidx_in_cluster = (Int32 (0 ), Int32 (0 ))
299305 else :
300306 # Get the pid from cluster id
@@ -326,7 +332,11 @@ def _delinearize_work_idx(
326332 if is_valid :
327333 if const_expr (params .persistence_mode in [PersistenceMode .NONE , PersistenceMode .CLC ]):
328334 cluster_id_in_problem = work_idx
329- _ , _ , bidz_ = cute .arch .cluster_idx ()
335+ bidz_ = (
336+ cute .arch .block_idx ()[2 ]
337+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 )
338+ else cute .arch .cluster_idx ()[2 ]
339+ )
330340 else :
331341 bidz_ , cluster_id_in_problem = divmod (work_idx , params .num_clusters_per_problem_fdd )
332342 if const_expr (bidz is not None ):
@@ -380,7 +390,11 @@ def initial_work_tile_info(self, *, loc=None, ip=None) -> cutlass.utils.WorkTile
380390 def _fetch_next_work_idx (self , * , loc = None , ip = None ) -> Int32 | Tuple [Int32 , Int32 , Boolean ]:
381391 """should only be called by the scheduler warp"""
382392 params = self .params
383- num_persistent_clusters = cute .arch .cluster_dim ()[2 ]
393+ num_persistent_clusters = (
394+ cute .arch .grid_dim ()[2 ]
395+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 )
396+ else cute .arch .cluster_dim ()[2 ]
397+ )
384398 if const_expr (params .persistence_mode == PersistenceMode .STATIC ):
385399 return self ._current_work_idx + num_persistent_clusters
386400 # Serpentine: alternate wave direction for a bit better load balancing
@@ -641,8 +655,12 @@ def create(
641655 loc = None ,
642656 ip = None ,
643657 ) -> "TriangularTileScheduler" :
658+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 ):
659+ cluster_idx = cute .arch .block_idx ()
660+ else :
661+ cluster_idx = cute .arch .cluster_idx ()
644662 current_work_idx , _ = TileScheduler ._cluster_idx_to_work_idx_batch (
645- params , cute . arch . cluster_idx () , loc = loc , ip = ip
663+ params , cluster_idx , loc = loc , ip = ip
646664 )
647665 stages = 0
648666 if const_expr (
@@ -762,7 +780,11 @@ def _delinearize_work_idx(
762780 if is_valid :
763781 if const_expr (params .persistence_mode in [PersistenceMode .NONE , PersistenceMode .CLC ]):
764782 cluster_id_in_problem = work_idx
765- _ , _ , bidz_ = cute .arch .cluster_idx ()
783+ bidz_ = (
784+ cute .arch .block_idx ()[2 ]
785+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 )
786+ else cute .arch .cluster_idx ()[2 ]
787+ )
766788 else :
767789 bidz_ , cluster_id_in_problem = divmod (work_idx , params .num_clusters_per_problem_fdd )
768790 cluster_id_in_problem = Int32 (cluster_id_in_problem ) # divmod returns IntValue
@@ -917,8 +939,12 @@ def create(
917939 loc = None ,
918940 ip = None ,
919941 ) -> "VarlenMTileScheduler" :
942+ if const_expr (cute .size (params .cluster_shape_mnk , loc = loc , ip = ip ) == 1 ):
943+ cluster_idx = cute .arch .block_idx ()
944+ else :
945+ cluster_idx = cute .arch .cluster_idx ()
920946 current_work_idx , _ = VarlenMTileScheduler ._cluster_idx_to_work_idx_batch (
921- params , cute . arch . cluster_idx () , loc = loc , ip = ip
947+ params , cluster_idx , loc = loc , ip = ip
922948 )
923949 stages = 0
924950 if const_expr (
0 commit comments