@@ -1051,12 +1051,16 @@ def raise_grid_block_minimums(self) -> None:
10511051 def lower_max_for_imbalanced_grid_dims (self ) -> None :
10521052 """Lower max_size for the large grid dimension when the shape is skinny.
10531053
1054- When one grid dimension is much larger than the other (e.g. M=1024, N=8192),
1055- the autotuner samples large tile sizes for the big dimension that produce too
1056- few grid blocks for good GPU occupancy. Capping the larger dimension's tile
1057- keeps the search in the useful region without hardcoding specific tile values.
1054+ Mirrors raise_grid_block_minimums: that method raises the minimum tile to
1055+ prevent too many blocks per dimension; this method lowers the maximum tile
1056+ to prevent too few blocks when one dimension is much larger than the other.
10581057
1059- Only applied to 2-D grids where max(dim) >= 4 * min(dim).
1058+ The cap is derived from num_compute_units() so it naturally relaxes on
1059+ smaller GPUs (fewer SMs/CUs → fewer blocks needed → larger tiles OK).
1060+
1061+ Only applied to 2-D grids where max(dim) >= 8 * min(dim). The 8x threshold
1062+ is hardware-independent: below it the random sampler has a reasonable chance
1063+ of finding good balanced-tile configs on its own.
10601064 """
10611065 if len (self .grid_block_ids ) != 2 :
10621066 return
@@ -1073,16 +1077,24 @@ def lower_max_for_imbalanced_grid_dims(self) -> None:
10731077
10741078 min_hint = min (hints )
10751079 max_hint = max (hints )
1076- if max_hint < min_hint * 4 :
1077- return # Square-ish shape — leave the search space alone
1080+ if max_hint < min_hint * 8 :
1081+ return # Not severely imbalanced — leave the search space alone
1082+
1083+ # Derive cap from actual hardware: require at least 1 block per compute unit
1084+ # per grid dimension (mirroring raise_grid_block_minimums which uses n_cus*64).
1085+ n_cus = num_compute_units ()
1086+ n_dims = len (self .grid_block_ids )
1087+ min_blocks_per_dim = math .ceil (n_cus ** (1.0 / n_dims ))
10781088
1079- # Cap the larger dim's tile so that at least 4 blocks cover min_hint,
1080- # keeping total blocks comparable across both dims.
1081- # e.g. M=1024, N=8192: cap N-tile at max(64, 1024//4) = 256
1082- cap = max (64 , next_power_of_2 (min_hint ) // 2 )
10831089 for spec , hint in zip (specs , hints , strict = True ):
1084- if hint == max_hint :
1085- spec .update_max (min (spec .max_size , cap ))
1090+ if hint != max_hint :
1091+ continue
1092+ max_tile = hint // min_blocks_per_dim
1093+ if max_tile < 2 :
1094+ continue
1095+ # Round down to power of two
1096+ max_tile = 1 << (max_tile .bit_length () - 1 )
1097+ spec .update_max (min (spec .max_size , max_tile ))
10861098
10871099 def create_config_generation (
10881100 self ,
0 commit comments