2727from helion .autotuner .base_search import PopulationBasedSearch
2828from helion .autotuner .base_search import PopulationMember
2929from helion .autotuner .benchmark_job import BenchmarkJob
30- from helion .autotuner .benchmark_job import RebenchmarkJob
3130from helion .autotuner .benchmark_job import _load_args
3231from helion .autotuner .benchmark_pool import PoolBenchmarkManager
3332from helion .autotuner .benchmark_provider import LocalBenchmarkProvider
@@ -306,27 +305,23 @@ def fake_fn_b() -> None:
306305 self .assertEqual (job .warmup , 25 )
307306 self .assertEqual (job .rep , 100 )
308307
309- def test_rebenchmark_uses_worker_pool (self ) -> None :
310- # Full-effort rebenchmarking should run on the worker that precompiled the config .
308+ def test_rebenchmark_uses_owner_isolated_worker_pool (self ) -> None :
309+ # Pool rebenchmarking should isolate each candidate on its owner worker .
311310 class FakePoolManager :
312311 def __init__ (self ) -> None :
313- self .worker_index : int | None = None
314- self .job : object | None = None
315- self .timeout : float | None = None
312+ self .calls : list [tuple [int , object , float ]] = []
316313
317- def worker_index_for_fn (self , _fn : object ) -> int :
318- return 3
314+ def worker_index_for_fn (self , fn : object ) -> int :
315+ return 3 if fn is fake_fn_a else 4
319316
320317 def run_job_on_worker (
321318 self ,
322319 worker_index : int ,
323320 job : object ,
324321 timeout : float ,
325- ) -> list [float ]:
326- self .worker_index = worker_index
327- self .job = job
328- self .timeout = timeout
329- return [1.0 , 2.0 ]
322+ ) -> float :
323+ self .calls .append ((worker_index , job , timeout ))
324+ return float (worker_index )
330325
331326 class FakeLog :
332327 def warning (self , * _args : object , ** _kwargs : object ) -> None :
@@ -335,7 +330,10 @@ def warning(self, *_args: object, **_kwargs: object) -> None:
335330 def debug (self , * _args : object , ** _kwargs : object ) -> None :
336331 pass
337332
338- def fake_fn () -> None :
333+ def fake_fn_a () -> None :
334+ pass
335+
336+ def fake_fn_b () -> None :
339337 pass
340338
341339 pool = FakePoolManager ()
@@ -358,15 +356,12 @@ def fake_fn() -> None:
358356 module_name = None ,
359357 )
360358
361- result = provider .rebenchmark ([fake_fn , fake_fn ], repeat = 7 , desc = "verify" )
359+ result = provider .rebenchmark ([fake_fn_a , fake_fn_b ], repeat = 7 , desc = "verify" )
362360
363- self .assertEqual (result , [1.0 , 2.0 ])
364- self .assertEqual (pool .worker_index , 3 )
365- self .assertIsInstance (pool .job , RebenchmarkJob )
366- assert isinstance (pool .job , RebenchmarkJob )
367- self .assertEqual (pool .job .repeat , 7 )
368- self .assertEqual (len (pool .job .fn_specs ), 2 )
369- self .assertEqual (pool .timeout , 20.0 )
361+ self .assertEqual (result , [3.0 , 4.0 ])
362+ self .assertEqual ([call [0 ] for call in pool .calls ], [3 , 4 ])
363+ self .assertTrue (all (isinstance (call [1 ], BenchmarkJob ) for call in pool .calls ))
364+ self .assertTrue (all (call [2 ] == 10.0 for call in pool .calls ))
370365
371366 def test_population_rebenchmark_uses_provider_timings (self ) -> None :
372367 # BaseSearch should use provider rebenchmark timings when available.
0 commit comments