Skip to content

Commit 8496337

Browse files
committed
[Autotuner] Add isolated worker benchmarking
stack-info: PR: #2294, branch: choijon5/stack/51
1 parent 43211f8 commit 8496337

2 files changed

Lines changed: 139 additions & 0 deletions

File tree

helion/autotuner/benchmark_provider.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,22 @@ def rebenchmark(
299299
"""
300300
return None
301301

302+
def benchmark_isolated(
    self,
    fns: list[Callable[..., object]],
    *,
    warmup: int,
    rep: int,
    desc: str = "Benchmarking",
) -> list[float] | None:
    """Hook for providers that can benchmark each callable in isolation.

    The default implementation opts out by returning ``None``, which tells
    the caller to fall back to its default in-process path.

    Overrides that do handle the request return a list with one timing per
    entry of ``fns``. A callable whose isolated run fails should be
    reported as an ``inf`` timing rather than by raising.
    """
    return None
317+
302318
@abc.abstractmethod
303319
def setup(self) -> None:
304320
"""Prepare resources needed before benchmarking begins (e.g. tmpdir)."""
@@ -1234,3 +1250,58 @@ def rebenchmark(
12341250
self.log.debug(f"{desc} subprocess raised: {type(e).__name__}: {e}")
12351251
self._autotune_metrics.num_compile_failures += len(fn_specs)
12361252
return [inf] * len(fn_specs)
1253+
1254+
def benchmark_isolated(
    self,
    fns: list[Callable[..., object]],
    *,
    warmup: int,
    rep: int,
    desc: str = "Benchmarking",
) -> list[float] | None:
    """Benchmark every callable as its own subprocess job.

    Args:
        fns: Compiled callables to time, one subprocess job per entry.
        warmup: Forwarded to each ``BenchmarkJob`` as ``warmup``.
        rep: Forwarded to each ``BenchmarkJob`` as ``rep``.
        desc: Prefix used in log messages emitted for failed jobs.

    Returns:
        ``None`` when this provider cannot take the isolated path:
        subprocess benchmarking is not enabled, a custom
        ``autotune_benchmark_fn`` is configured, there is no precompile
        args path, or at least one callable cannot be serialized for a
        worker. Otherwise one timing per callable, in input order. A job
        that raises is recorded as ``inf`` and counted in
        ``num_compile_failures``.
    """
    if not self._subprocess_benchmark_enabled():
        return None
    if self.settings.autotune_benchmark_fn is not None:
        return None
    if self._precompile_args_path is None:
        return None

    # Serialize everything before dispatching any job. If one callable
    # cannot be shipped to a worker the caller falls back to its in-process
    # path for the whole batch, so bailing out here avoids running (and
    # then discarding) subprocess jobs and avoids bumping the failure
    # counter for results nobody will use.
    serialized = []
    for fn in fns:
        fn_spec = self._serialize_fn_for_worker(cast("CompiledConfig", fn))
        if fn_spec is None:
            return None
        serialized.append((fn, fn_spec))

    timeout = float(self.settings.autotune_benchmark_timeout)
    timings: list[float] = []
    for fn, fn_spec in serialized:
        job = BenchmarkJob(
            fn_spec=fn_spec,
            args_path=self._precompile_args_path,
            warmup=warmup,
            rep=rep,
        )
        # With a pool, ask the pool manager which worker this callable
        # belongs to; without one, worker index 0 is used.
        worker_index = (
            self._pool_manager.worker_index_for_fn(fn)
            if self._pool_manager is not None
            else 0
        )
        try:
            timing = self._run_subprocess_job(
                job,
                timeout,
                worker_index=worker_index,
            )
        except Exception as e:
            if isinstance(e, BenchmarkSubprocessError):
                self.log.warning(f"{desc} subprocess failed: {e}")
            else:
                # Drop the traceback reference before logging.
                # NOTE(review): presumably this releases frames that may
                # hold large objects -- confirm intent.
                e.__traceback__ = None
                if match_unrecoverable_runtime_error(e):
                    self.log.warning(f"{desc} sticky CUDA error skipped: {e}")
                else:
                    self.log.debug(
                        f"{desc} subprocess raised: {type(e).__name__}: {e}"
                    )
            # Any failure is a per-callable failure: record it and keep
            # benchmarking the remaining callables.
            self._autotune_metrics.num_compile_failures += 1
            timing = inf
        timings.append(float(timing))
    return timings

test/test_benchmark_worker.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from helion._testing import skipIfXPU
2727
from helion.autotuner.base_search import PopulationBasedSearch
2828
from helion.autotuner.base_search import PopulationMember
29+
from helion.autotuner.benchmark_job import BenchmarkJob
2930
from helion.autotuner.benchmark_job import RebenchmarkJob
3031
from helion.autotuner.benchmark_job import _load_args
3132
from helion.autotuner.benchmark_pool import PoolBenchmarkManager
@@ -238,6 +239,73 @@ def test_pool_mode_reports_disabled_worker_reason(self) -> None:
238239
assert reason is not None
239240
self.assertIn("disabled", reason)
240241

242+
def test_benchmark_isolated_routes_each_fn_to_owner_worker(self) -> None:
    """benchmark_isolated must send each callable to the worker that owns it."""

    def fake_fn_a() -> None:
        pass

    def fake_fn_b() -> None:
        pass

    class FakeLog:
        """Logger stand-in that discards every message."""

        def warning(self, *_args: object, **_kwargs: object) -> None:
            pass

        def debug(self, *_args: object, **_kwargs: object) -> None:
            pass

    class FakePoolManager:
        """Records every dispatched job and echoes the worker index as timing."""

        def __init__(self) -> None:
            self.calls: list[tuple[int, object, float]] = []

        def worker_index_for_fn(self, fn: object) -> int:
            # fake_fn_a is owned by worker 2; anything else by worker 1.
            if fn is fake_fn_a:
                return 2
            return 1

        def run_job_on_worker(
            self,
            worker_index: int,
            job: object,
            timeout: float,
        ) -> float:
            self.calls.append((worker_index, job, timeout))
            return float(worker_index)

    def serialize_stub(_fn: object) -> SerializedCompiledFunction:
        return SerializedCompiledFunction(
            function_name="fake_fn",
            source_code="def fake_fn(): pass",
            filename=None,
            module_name=None,
        )

    pool = FakePoolManager()
    # Bypass __init__ and set only the attributes used on this path.
    provider = cast("Any", LocalBenchmarkProvider.__new__(LocalBenchmarkProvider))
    provider.settings = Settings(
        autotune_precompile="pool",
        autotune_benchmark_timeout=10,
    )
    provider.config_spec = SimpleNamespace(backend=None)
    provider.mutated_arg_indices = []
    provider._precompile_args_path = "args.pt"
    provider._pool_manager = pool
    provider._benchmark_worker = None
    provider._autotune_metrics = SimpleNamespace(num_compile_failures=0)
    provider.log = FakeLog()
    provider._serialize_fn_for_worker = serialize_stub

    result = provider.benchmark_isolated(
        [fake_fn_a, fake_fn_b],
        warmup=25,
        rep=100,
        desc="verify",
    )

    # Timings echo the worker index, so they also show the routing.
    self.assertEqual(result, [2.0, 1.0])
    routed_to = [worker_index for worker_index, _, _ in pool.calls]
    self.assertEqual(routed_to, [2, 1])
    self.assertTrue(all(isinstance(job, BenchmarkJob) for _, job, _ in pool.calls))
    self.assertTrue(all(timeout == 10.0 for _, _, timeout in pool.calls))
    for _, job, _ in pool.calls:
        assert isinstance(job, BenchmarkJob)
        self.assertEqual(job.warmup, 25)
        self.assertEqual(job.rep, 100)
308+
241309
def test_rebenchmark_uses_worker_pool(self) -> None:
242310
# Full-effort rebenchmarking should run on the worker that precompiled the config.
243311
class FakePoolManager:

0 commit comments

Comments
 (0)