Skip to content

Commit 4879c3f

Browse files
committed
Using the hardware specs to derive the cap
1 parent 1a31173 commit 4879c3f

2 files changed

Lines changed: 36 additions & 18 deletions

File tree

helion/autotuner/config_spec.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,12 +1051,16 @@ def raise_grid_block_minimums(self) -> None:
10511051
def lower_max_for_imbalanced_grid_dims(self) -> None:
10521052
"""Lower max_size for the large grid dimension when the shape is skinny.
10531053
1054-
When one grid dimension is much larger than the other (e.g. M=1024, N=8192),
1055-
the autotuner samples large tile sizes for the big dimension that produce too
1056-
few grid blocks for good GPU occupancy. Capping the larger dimension's tile
1057-
keeps the search in the useful region without hardcoding specific tile values.
1054+
Mirrors raise_grid_block_minimums: that method raises the minimum tile to
1055+
prevent too many blocks per dimension; this method lowers the maximum tile
1056+
to prevent too few blocks when one dimension is much larger than the other.
10581057
1059-
Only applied to 2-D grids where max(dim) >= 4 * min(dim).
1058+
The cap is derived from num_compute_units() so it naturally relaxes on
1059+
smaller GPUs (fewer SMs/CUs → fewer blocks needed → larger tiles OK).
1060+
1061+
Only applied to 2-D grids where max(dim) >= 8 * min(dim). The 8x threshold
1062+
is hardware-independent: below it the random sampler has a reasonable chance
1063+
of finding good balanced-tile configs on its own.
10601064
"""
10611065
if len(self.grid_block_ids) != 2:
10621066
return
@@ -1073,16 +1077,24 @@ def lower_max_for_imbalanced_grid_dims(self) -> None:
10731077

10741078
min_hint = min(hints)
10751079
max_hint = max(hints)
1076-
if max_hint < min_hint * 4:
1077-
return # Square-ish shape — leave the search space alone
1080+
if max_hint < min_hint * 8:
1081+
return # Not severely imbalanced — leave the search space alone
1082+
1083+
# Derive cap from actual hardware: require ceil(n_cus ** (1 / n_dims)) blocks per grid
1084+
# dimension, so the whole grid has at least 1 block per compute unit (mirroring raise_grid_block_minimums which uses n_cus*64).
1085+
n_cus = num_compute_units()
1086+
n_dims = len(self.grid_block_ids)
1087+
min_blocks_per_dim = math.ceil(n_cus ** (1.0 / n_dims))
10781088

1079-
# Cap the larger dim's tile so that at least 4 blocks cover min_hint,
1080-
# keeping total blocks comparable across both dims.
1081-
# e.g. M=1024, N=8192: cap N-tile at max(64, 1024//4) = 256
1082-
cap = max(64, next_power_of_2(min_hint) // 2)
10831089
for spec, hint in zip(specs, hints, strict=True):
1084-
if hint == max_hint:
1085-
spec.update_max(min(spec.max_size, cap))
1090+
if hint != max_hint:
1091+
continue
1092+
max_tile = hint // min_blocks_per_dim
1093+
if max_tile < 2:
1094+
continue
1095+
# Round down to power of two
1096+
max_tile = 1 << (max_tile.bit_length() - 1)
1097+
spec.update_max(min(spec.max_size, max_tile))
10861098

10871099
def create_config_generation(
10881100
self,

test/test_best_available.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,11 @@ def _make_spec(self, m, n, m_max=None, n_max=None):
11001100
return config_spec
11011101

11021102
def test_skinny_n_caps_n_tile(self):
1103-
"""Skinny-N (M=1024, N=8192): N-tile max should be capped."""
1103+
"""Skinny-N (M=1024, N=8192): N-tile max should be capped (hardware-derived cap)."""
1104+
import math
1105+
1106+
from helion._compat import num_compute_units
1107+
11041108
spec = self._make_spec(m=1024, n=8192)
11051109
n_max_before = spec.block_sizes.block_id_lookup(1).max_size
11061110
spec.lower_max_for_imbalanced_grid_dims()
@@ -1110,8 +1114,10 @@ def test_skinny_n_caps_n_tile(self):
11101114
self.assertLess(n_max_after, n_max_before)
11111115
# M tile (the smaller dim) should be unchanged
11121116
self.assertEqual(m_max_after, spec.block_sizes.block_id_lookup(0).max_size)
1113-
# Cap should allow at least 4 blocks on N dim
1114-
self.assertGreaterEqual(8192 // n_max_after, 4)
1117+
# Cap is hardware-derived: at least min_blocks_per_dim blocks on N dim
1118+
n_cus = num_compute_units()
1119+
min_blocks_per_dim = math.ceil(n_cus**0.5)
1120+
self.assertGreaterEqual(8192 // n_max_after, min_blocks_per_dim)
11151121

11161122
def test_skinny_m_caps_m_tile(self):
11171123
"""Skinny-M (M=8192, N=1024): M-tile max should be capped."""
@@ -1131,8 +1137,8 @@ def test_square_shape_unchanged(self):
11311137
self.assertEqual(spec.block_sizes.block_id_lookup(1).max_size, n_max_before)
11321138

11331139
def test_slightly_imbalanced_unchanged(self):
1134-
"""3:1 ratio (M=1024, N=3072): below the 4x threshold, no change."""
1135-
spec = self._make_spec(m=1024, n=3072)
1140+
"""6.4x ratio (M=1280, N=8192): below the 8x threshold, no change (covers int4_gemm regression)."""
1141+
spec = self._make_spec(m=1280, n=8192)
11361142
n_max_before = spec.block_sizes.block_id_lookup(1).max_size
11371143
spec.lower_max_for_imbalanced_grid_dims()
11381144
self.assertEqual(spec.block_sizes.block_id_lookup(1).max_size, n_max_before)

0 commit comments

Comments
 (0)