Skip to content

Commit 2bcd8e4

Browse files
committed
[Autotuner] Add long-lived benchmark worker pool
stack-info: PR: #2289, branch: choijon5/stack/46
1 parent e9d6c02 commit 2bcd8e4

2 files changed

Lines changed: 123 additions & 2 deletions

File tree

helion/autotuner/benchmark_worker.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
import ctypes.util
88
import multiprocessing as mp
99
import os
10+
import queue
1011
import signal
1112
import sys
13+
import threading
14+
import time
1215
from typing import TYPE_CHECKING
1316
from typing import Callable
17+
from typing import NamedTuple
1418
from typing import TypeVar
1519

1620
from .logger import _UNRECOVERABLE_RUNTIME_ERROR_RE
@@ -21,6 +25,12 @@
2125
_T = TypeVar("_T")
2226

2327

28+
class WorkerPoolResult(NamedTuple):
    """Outcome of a single job dispatched by ``BenchmarkWorkerPool.run_jobs``."""

    # Index (into the pool's worker list) of the worker that ran the job.
    worker_index: int
    # Seconds spent running the job, measured with ``time.perf_counter``.
    elapsed: float
    # The job's return value, or the exception object captured while running it.
    result: object
32+
33+
2434
def _set_pdeathsig() -> None:
2535
"""SIGTERM the child if the parent dies (Linux only, best-effort)."""
2636
if sys.platform != "linux":
@@ -146,3 +156,97 @@ def _kill(self) -> None:
146156
connection.close()
147157
self._process = None
148158
self._parent_connection = None
159+
160+
161+
class BenchmarkWorkerPool:
    """Pool of long-lived ``BenchmarkWorker`` processes.

    The pool holds a fixed list of workers, each created with ``device=None``.
    A job can be sent to one specific worker (``run_job_on_worker``) or a batch
    of jobs can be fanned out over the pool (``run_jobs``).
    """

    def __init__(self, num_workers: int) -> None:
        """Create the pool.

        Args:
            num_workers: Number of workers to create; must be at least 1.

        Raises:
            ValueError: If ``num_workers`` is smaller than 1.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        self.workers = [BenchmarkWorker(device=None) for _ in range(num_workers)]

    @property
    def num_workers(self) -> int:
        """Number of workers held by the pool."""
        return len(self.workers)

    def run_job_on_worker(
        self, worker_index: int, job: Callable[[], _T], timeout: float
    ) -> _T:
        """Run ``job`` on one worker and return what the worker's ``run`` returns.

        ``worker_index`` is taken modulo ``num_workers``, so any integer selects
        a valid worker. Exceptions raised by the worker's ``run`` propagate.
        """
        return self.workers[worker_index % self.num_workers].run(job, timeout=timeout)

    def run_jobs(
        self, jobs: list[Callable[[], object]], timeout: float
    ) -> list[WorkerPoolResult]:
        """Run jobs across the worker pool while preserving input order.

        Each worker thread owns one worker and pulls job indices from a shared
        queue, so slow jobs do not block unrelated workers. Worker exceptions
        are captured in ``WorkerPoolResult.result`` for the corresponding job.

        The workers that will be used are started from the calling thread
        before any dispatch thread exists (see ``start_all``). Starting is
        best-effort: if it fails, the affected workers are left for the
        dispatch path to start, and the failure is then reported through that
        job's ``WorkerPoolResult.result`` rather than raised from here.

        Args:
            jobs: Zero-argument callables; ``jobs[i]`` produces result ``i``.
            timeout: Per-job timeout forwarded to each worker's ``run``.

        Returns:
            One ``WorkerPoolResult`` per job, in the same order as ``jobs``.
            An empty ``jobs`` list returns ``[]`` without touching any worker.
        """
        if not jobs:
            return []
        active_workers = min(self.num_workers, len(jobs))

        # Spawn workers here rather than lazily inside a dispatch thread, so
        # that a worker process is not parented to a thread that exits as soon
        # as the queue drains. Errors are suppressed on purpose: a worker that
        # could not be started is retried by the normal per-job path below.
        with contextlib.suppress(Exception):
            self.start_all(active_workers)

        result_slots: list[WorkerPoolResult | None] = [None] * len(jobs)
        work_queue: queue.Queue[int] = queue.Queue()
        for i in range(len(jobs)):
            work_queue.put(i)

        def process_queue(worker_idx: int) -> None:
            # One dispatch loop per worker: keep taking job indices until the
            # shared queue is empty, writing each result into its own slot.
            worker = self.workers[worker_idx]
            while True:
                try:
                    i = work_queue.get_nowait()
                except queue.Empty:
                    return
                start = time.perf_counter()
                job_result = _run_job_capture_error(worker, jobs[i], timeout)
                result_slots[i] = WorkerPoolResult(
                    worker_index=worker_idx,
                    elapsed=time.perf_counter() - start,
                    result=job_result,
                )

        _run_worker_threads(process_queue, active_workers)
        ordered_results: list[WorkerPoolResult] = []
        for slot in result_slots:
            # Every index was queued exactly once, so every slot must be filled.
            assert slot is not None
            ordered_results.append(slot)
        return ordered_results

    def start_all(self, limit: int | None = None) -> None:
        """Start workers before threaded dispatch so their lifetime is not
        tied to short-lived dispatch threads.

        Args:
            limit: Only the first ``limit`` workers are considered; ``None``
                means all of them. Workers that are already alive are skipped.
        """
        if limit is None:
            limit = self.num_workers
        for worker in self.workers[:limit]:
            if not worker.alive():
                worker._start()

    def shutdown(self) -> None:
        """Shut down every worker, ignoring errors from individual workers."""
        for w in self.workers:
            # Best-effort: one failing worker must not prevent the remaining
            # workers from being shut down.
            with contextlib.suppress(Exception):
                w.shutdown()
230+
231+
232+
def _run_job_capture_error(
    worker: BenchmarkWorker, job: Callable[[], object], timeout: float
) -> object:
    """Run ``job`` on ``worker`` and return either its result or its exception.

    Nothing is raised to the caller: any exception coming out of
    ``worker.run`` (``BaseException`` included) is returned as the value, with
    its traceback cleared.
    """
    try:
        outcome = worker.run(job, timeout=timeout)
    except BaseException as captured:
        # The exception is kept as a result value; its traceback is dropped
        # first (presumably so the stored object does not keep frames alive).
        captured.__traceback__ = None
        return captured
    else:
        return outcome
240+
241+
242+
def _run_worker_threads(target: Callable[[int], None], n: int) -> None:
    """Call ``target(i)`` for every ``i`` in ``range(n)`` and wait for all calls.

    With ``n == 1`` the call happens inline in the current thread. Otherwise
    each index gets its own daemon thread; all threads are started first and
    then joined, so the function returns only once every call has finished.
    """
    if n == 1:
        # A single worker needs no extra thread.
        target(0)
    else:
        pool_threads = []
        for index in range(n):
            pool_threads.append(
                threading.Thread(target=target, args=(index,), daemon=True)
            )
        for thread in pool_threads:
            thread.start()
        for thread in pool_threads:
            thread.join()

test/test_benchmark_worker.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from helion.autotuner.benchmark_worker import BenchmarkTimeout
2525
from helion.autotuner.benchmark_worker import BenchmarkWorker
2626
from helion.autotuner.benchmark_worker import BenchmarkWorkerDied
27+
from helion.autotuner.benchmark_worker import BenchmarkWorkerPool
2728
from helion.autotuner.random_search import RandomSearch
2829

2930
if TYPE_CHECKING:
@@ -69,6 +70,7 @@ def __call__(self) -> object:
6970

7071
class TestBenchmarkWorkerFailureModes(unittest.TestCase):
7172
def test_timeout_kills_worker(self) -> None:
73+
# A timed-out job should kill the worker and the next job should respawn it.
7274
worker = BenchmarkWorker()
7375
try:
7476
t0 = time.time()
@@ -82,8 +84,7 @@ def test_timeout_kills_worker(self) -> None:
8284
worker.shutdown()
8385

8486
def test_sticky_error_kills_worker(self) -> None:
85-
# Errors matching _UNRECOVERABLE_RUNTIME_ERROR_RE force the worker
86-
# to be killed so the next call spawns a fresh CUDA context.
87+
# Sticky CUDA-style errors should kill the worker before the next job.
8788
worker = BenchmarkWorker()
8889
try:
8990
with self.assertRaises(RuntimeError) as ctx:
@@ -95,6 +96,7 @@ def test_sticky_error_kills_worker(self) -> None:
9596
worker.shutdown()
9697

9798
def test_worker_crash_raises_died(self) -> None:
99+
# A worker process crash should surface as BenchmarkWorkerDied.
98100
worker = BenchmarkWorker()
99101
try:
100102
with self.assertRaises(BenchmarkWorkerDied):
@@ -103,6 +105,21 @@ def test_worker_crash_raises_died(self) -> None:
103105
finally:
104106
worker.shutdown()
105107

108+
def test_pool_run_jobs_reports_worker_and_elapsed(self) -> None:
    # Results must come back in job order, each tagged with a valid worker
    # index and a non-negative elapsed time.
    worker_count = 2
    pool = BenchmarkWorkerPool(worker_count)
    try:
        jobs = [_ReturnValue("a"), _ReturnValue("b")]
        results = pool.run_jobs(jobs, timeout=30.0)
    finally:
        # Always release the worker processes, even if dispatch failed.
        pool.shutdown()

    self.assertEqual([entry.result for entry in results], ["a", "b"])
    for entry in results:
        self.assertTrue(0 <= entry.worker_index < worker_count)
        self.assertTrue(entry.elapsed >= 0)
122+
106123

107124
# Subprocess benchmarking depends on Backend.supports_precompile(); only the
108125
# Triton backend supports it (Pallas/CuTe return False).

0 commit comments

Comments
 (0)