Skip to content

Commit 15f1c41

Browse files
committed
[Autotuner] Rebenchmark pool configs on owner workers
stack-info: PR: #2295, branch: choijon5/stack/52
1 parent a23c029 commit 15f1c41

2 files changed

Lines changed: 26 additions & 25 deletions

File tree

helion/autotuner/benchmark_provider.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,14 @@ def rebenchmark(
12151215
if self._precompile_args_path is None:
12161216
return None
12171217

1218+
if self._pool_manager is not None:
1219+
return self.benchmark_isolated(
1220+
fns,
1221+
warmup=25,
1222+
rep=100,
1223+
desc=desc,
1224+
)
1225+
12181226
fn_specs: list[SerializedCompiledFunction] = []
12191227
for fn in fns:
12201228
fn_spec = self._serialize_fn_for_worker(cast("CompiledConfig", fn))
@@ -1298,9 +1306,7 @@ def benchmark_isolated(
12981306
if match_unrecoverable_runtime_error(e):
12991307
self.log.warning(f"{desc} sticky CUDA error skipped: {e}")
13001308
else:
1301-
self.log.debug(
1302-
f"{desc} subprocess raised: {type(e).__name__}: {e}"
1303-
)
1309+
self.log.debug(f"{desc} subprocess raised: {type(e).__name__}: {e}")
13041310
self._autotune_metrics.num_compile_failures += 1
13051311
timing = inf
13061312
timings.append(float(timing))

test/test_benchmark_worker.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from helion.autotuner.base_search import PopulationBasedSearch
2828
from helion.autotuner.base_search import PopulationMember
2929
from helion.autotuner.benchmark_job import BenchmarkJob
30-
from helion.autotuner.benchmark_job import RebenchmarkJob
3130
from helion.autotuner.benchmark_job import _load_args
3231
from helion.autotuner.benchmark_pool import PoolBenchmarkManager
3332
from helion.autotuner.benchmark_provider import LocalBenchmarkProvider
@@ -306,27 +305,23 @@ def fake_fn_b() -> None:
306305
self.assertEqual(job.warmup, 25)
307306
self.assertEqual(job.rep, 100)
308307

309-
def test_rebenchmark_uses_worker_pool(self) -> None:
310-
# Full-effort rebenchmarking should run on the worker that precompiled the config.
308+
def test_rebenchmark_uses_owner_isolated_worker_pool(self) -> None:
309+
# Pool rebenchmarking should isolate each candidate on its owner worker.
311310
class FakePoolManager:
312311
def __init__(self) -> None:
313-
self.worker_index: int | None = None
314-
self.job: object | None = None
315-
self.timeout: float | None = None
312+
self.calls: list[tuple[int, object, float]] = []
316313

317-
def worker_index_for_fn(self, _fn: object) -> int:
318-
return 3
314+
def worker_index_for_fn(self, fn: object) -> int:
315+
return 3 if fn is fake_fn_a else 4
319316

320317
def run_job_on_worker(
321318
self,
322319
worker_index: int,
323320
job: object,
324321
timeout: float,
325-
) -> list[float]:
326-
self.worker_index = worker_index
327-
self.job = job
328-
self.timeout = timeout
329-
return [1.0, 2.0]
322+
) -> float:
323+
self.calls.append((worker_index, job, timeout))
324+
return float(worker_index)
330325

331326
class FakeLog:
332327
def warning(self, *_args: object, **_kwargs: object) -> None:
@@ -335,7 +330,10 @@ def warning(self, *_args: object, **_kwargs: object) -> None:
335330
def debug(self, *_args: object, **_kwargs: object) -> None:
336331
pass
337332

338-
def fake_fn() -> None:
333+
def fake_fn_a() -> None:
334+
pass
335+
336+
def fake_fn_b() -> None:
339337
pass
340338

341339
pool = FakePoolManager()
@@ -358,15 +356,12 @@ def fake_fn() -> None:
358356
module_name=None,
359357
)
360358

361-
result = provider.rebenchmark([fake_fn, fake_fn], repeat=7, desc="verify")
359+
result = provider.rebenchmark([fake_fn_a, fake_fn_b], repeat=7, desc="verify")
362360

363-
self.assertEqual(result, [1.0, 2.0])
364-
self.assertEqual(pool.worker_index, 3)
365-
self.assertIsInstance(pool.job, RebenchmarkJob)
366-
assert isinstance(pool.job, RebenchmarkJob)
367-
self.assertEqual(pool.job.repeat, 7)
368-
self.assertEqual(len(pool.job.fn_specs), 2)
369-
self.assertEqual(pool.timeout, 20.0)
361+
self.assertEqual(result, [3.0, 4.0])
362+
self.assertEqual([call[0] for call in pool.calls], [3, 4])
363+
self.assertTrue(all(isinstance(call[1], BenchmarkJob) for call in pool.calls))
364+
self.assertTrue(all(call[2] == 10.0 for call in pool.calls))
370365

371366
def test_population_rebenchmark_uses_provider_timings(self) -> None:
372367
# BaseSearch should use provider rebenchmark timings when available.

0 commit comments

Comments
 (0)