Skip to content

Commit cf6342d

Browse files
committed
[Autotuner] Add pool precompile manager
1 parent f0b7419 commit cf6342d

2 files changed

Lines changed: 206 additions & 1 deletion

File tree

helion/autotuner/benchmark_pool.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from typing import Callable
5+
from typing import Literal
6+
from typing import NamedTuple
7+
from typing import TypeVar
8+
from typing import cast
9+
10+
import torch
11+
from torch.utils._pytree import tree_map_only
12+
13+
from .benchmark_job import PrecompileJob
14+
from .benchmark_worker import BenchmarkSubprocessError
15+
from .benchmark_worker import BenchmarkWorkerPool
16+
17+
if TYPE_CHECKING:
18+
from ..runtime.config import Config
19+
from ..runtime.kernel import CompiledConfig
20+
from .logger import AutotuningLogger
21+
from .metrics import AutotuneMetrics
22+
from .precompile_future import SerializedCompiledFunction
23+
24+
_T = TypeVar("_T")
25+
26+
27+
class PoolPrecompileResult(NamedTuple):
    """Outcome of one ``PoolBenchmarkManager.precompile`` call.

    The three lists are parallel: entry ``i`` of each list describes the
    config at position ``i`` of the ``configs`` argument passed to
    ``precompile``.
    """

    # True only for configs whose worker job returned exactly ``True``.
    is_workings: list[bool]
    # "ok" on success, "timeout" for subprocess errors whose message mentions
    # a timeout, "error" for everything else (including jobs never submitted).
    statuses: list[Literal["ok", "error", "timeout"]]
    # Elapsed time reported by the pool for the job, or None when the config
    # was never submitted to a worker.
    compile_times: list[float | None]
31+
32+
33+
def estimate_tree_bytes(obj: object) -> int:
    """Return an estimate of the tensor bytes held by the pytree ``obj``.

    Every ``torch.Tensor`` leaf is visited.  A tensor whose untyped storage
    is accessible contributes that storage's full ``nbytes()``, and each
    distinct storage data pointer is counted at most once, so views that
    share a storage are not double counted.  A tensor whose storage cannot
    be accessed (``untyped_storage()`` raises ``RuntimeError``) contributes
    ``element_size() * numel()`` instead.  Non-tensor leaves are ignored.
    """
    visited_ptrs: set[int] = set()
    contributions: list[int] = []

    def _visit(leaf: torch.Tensor) -> torch.Tensor:
        # Computed up front so it is ready as the fallback size.
        logical_bytes = leaf.element_size() * leaf.numel()
        try:
            backing = leaf.untyped_storage()
        except RuntimeError:
            # No accessible storage: fall back to the logical tensor size.
            contributions.append(logical_bytes)
            return leaf

        address = backing.data_ptr()
        if address not in visited_ptrs:
            visited_ptrs.add(address)
            contributions.append(backing.nbytes())
        return leaf

    # tree_map_only is used purely for traversal; the mapped tree is discarded.
    tree_map_only(torch.Tensor, _visit, obj)
    return sum(contributions)
56+
57+
58+
class PoolBenchmarkManager:
    """Hold the worker pool used during one autotune call.

    Besides forwarding work to a ``BenchmarkWorkerPool``, the manager
    remembers which worker successfully precompiled each function so that
    callers can later look that worker up again.
    """

    def __init__(
        self,
        *,
        num_workers: int,
        log: AutotuningLogger,
        autotune_metrics: AutotuneMetrics,
    ) -> None:
        """Create the pool and the (initially empty) function-to-worker map.

        Args:
            num_workers: Passed straight to ``BenchmarkWorkerPool``.
            log: Logger used for debug lines and the final progress message.
            autotune_metrics: Object whose ``num_compile_failures`` counter
                is incremented for every config that fails to precompile.
        """
        self._pool = BenchmarkWorkerPool(num_workers)
        # Maps id(fn) -> index of the worker that precompiled fn.  Keys are
        # object ids, so an entry is only meaningful while fn stays alive.
        self._precompile_worker_by_fn: dict[int, int] = {}
        self._autotune_metrics = autotune_metrics
        self._log = log

    def shutdown(self) -> None:
        """Shut the pool down, then forget all function-to-worker entries."""
        self._pool.shutdown()
        self._precompile_worker_by_fn.clear()

    def worker_index_for_fn(self, fn: Callable[..., object]) -> int:
        """Return the worker index recorded for ``fn``, or 0 if none exists."""
        key = id(fn)
        return self._precompile_worker_by_fn.get(key, 0)

    def run_on(self, worker_index: int, job: Callable[[], _T], timeout: float) -> _T:
        """Forward ``job`` to the pool's ``run_on`` for the given worker."""
        outcome = self._pool.run_on(worker_index, job, timeout=timeout)
        return outcome

    def precompile(
        self,
        configs: list[Config],
        fns: list[CompiledConfig],
        *,
        args_path: str,
        timeout: float,
        desc: str | None,
        serialize_fn: Callable[[CompiledConfig], SerializedCompiledFunction | None],
    ) -> PoolPrecompileResult:
        """Precompile ``fns`` in the worker pool and report per-config results.

        Args:
            configs: Configs being compiled; used for log messages and to
                size the returned lists.
            fns: Compiled functions, positionally matched with ``configs``.
            args_path: Forwarded to every ``PrecompileJob``.
            timeout: Forwarded to the pool's ``map_jobs``.
            desc: When truthy, a completion message prefixed with it is logged.
            serialize_fn: Turns a function into a serialized spec, or returns
                ``None`` when it cannot; such configs are never submitted and
                are counted as compile failures.

        Returns:
            A ``PoolPrecompileResult`` with one entry per config.
        """

        def _job_for(fn: CompiledConfig) -> PrecompileJob | None:
            # Serialize first; only serializable functions become jobs.
            spec = serialize_fn(fn)
            if spec is None:
                return None
            return PrecompileJob(fn_spec=spec, args_path=args_path)

        def _record_failure(message: str) -> None:
            self._log.debug(message)
            self._autotune_metrics.num_compile_failures += 1

        jobs = [_job_for(fn) for fn in fns]

        # Split positions into those with a runnable job and those without.
        live_idxs: list[int] = []
        skipped_idxs: list[int] = []
        runnable: list[PrecompileJob] = []
        for position, job in enumerate(jobs):
            if job is None:
                skipped_idxs.append(position)
            else:
                live_idxs.append(position)
                runnable.append(job)
        live_jobs = cast("list[Callable[[], object]]", runnable)

        self._pool.start_all(limit=len(live_jobs))
        live_results = self._pool.map_jobs(live_jobs, timeout=timeout)

        # Pessimistic defaults: every config starts out as a failed compile.
        total = len(configs)
        is_workings: list[bool] = [False] * total
        statuses: list[Literal["ok", "error", "timeout"]] = ["error"] * total
        compile_times: list[float | None] = [None] * total

        for idx in skipped_idxs:
            _record_failure(f"Precompile worker could not serialize {configs[idx]!r}")

        for idx, outcome in zip(live_idxs, live_results, strict=True):
            compile_times[idx] = outcome.elapsed
            job_result = outcome.result

            if job_result is True:
                is_workings[idx] = True
                statuses[idx] = "ok"
                self._precompile_worker_by_fn[id(fns[idx])] = outcome.worker_index
                continue

            if not isinstance(job_result, BaseException):
                # Any non-True, non-exception value keeps the "error" status.
                _record_failure(
                    f"Precompile worker returned failure for {configs[idx]!r}: "
                    f"{job_result!r}"
                )
                continue

            # NOTE(review): timeouts are recognised by the word "timeout" in
            # the exception text - confirm this matches what the pool raises.
            timed_out = (
                isinstance(job_result, BenchmarkSubprocessError)
                and "timeout" in str(job_result).lower()
            )
            statuses[idx] = "timeout" if timed_out else "error"
            _record_failure(
                f"Precompile worker failed for {configs[idx]!r}: "
                f"{type(job_result).__name__}: {job_result}"
            )

        if desc:
            self._log(f"{desc} 100% via worker pool ({len(live_idxs)} configs)")
        return PoolPrecompileResult(is_workings, statuses, compile_times)

test/test_benchmark_worker.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import signal
1111
import tempfile
1212
import time
13+
from types import SimpleNamespace
1314
from typing import TYPE_CHECKING
15+
from typing import Any
16+
from typing import cast
1417
import unittest
1518
from unittest.mock import patch
1619

@@ -22,15 +25,18 @@
2225
from helion._testing import onlyBackends
2326
from helion._testing import skipIfXPU
2427
from helion.autotuner.benchmark_job import _load_args
28+
from helion.autotuner.benchmark_pool import PoolBenchmarkManager
2529
from helion.autotuner.benchmark_provider import LocalBenchmarkProvider
2630
from helion.autotuner.benchmark_worker import BenchmarkTimeout
2731
from helion.autotuner.benchmark_worker import BenchmarkWorker
2832
from helion.autotuner.benchmark_worker import BenchmarkWorkerDied
2933
from helion.autotuner.benchmark_worker import BenchmarkWorkerPool
34+
from helion.autotuner.benchmark_worker import WorkerPoolResult
35+
from helion.autotuner.precompile_future import SerializedCompiledFunction
3036
from helion.autotuner.random_search import RandomSearch
37+
from helion.runtime.config import Config
3138

3239
if TYPE_CHECKING:
33-
from helion.runtime.config import Config
3440
from helion.runtime.kernel import CompiledConfig
3541

3642

@@ -141,6 +147,59 @@ def test_worker_arg_loading_allows_callable_kernel_args(self) -> None:
141147

142148
self.assertIs(loaded[0], _callable_kernel_arg)
143149

150+
def test_false_precompile_result_is_failure(self) -> None:
    """A worker precompile that returns ``False`` is a real compile failure."""

    class StubPool:
        """Pool stand-in that answers ``False`` for every submitted job."""

        def start_all(self, limit: int | None = None) -> None:
            self.limit = limit

        def map_jobs(
            self,
            jobs: list[object],
            timeout: float,
        ) -> list[WorkerPoolResult]:
            canned: list[WorkerPoolResult] = []
            for _ in jobs:
                canned.append(
                    WorkerPoolResult(worker_index=0, elapsed=0.25, result=False)
                )
            return canned

    class SilentLog:
        """Logger stand-in that drops debug output."""

        def debug(self, *_args: object, **_kwargs: object) -> None:
            return None

    def fake_fn() -> None:
        return None

    def serialize_fn(_fn: object) -> SerializedCompiledFunction:
        # Always serializable, so the job really is handed to the pool.
        return SerializedCompiledFunction(
            function_name="fake_fn",
            source_code="def fake_fn(): pass",
            filename=None,
            module_name=None,
        )

    # Bypass __init__ so that no real worker pool is created.
    metrics = SimpleNamespace(num_compile_failures=0)
    manager = cast("Any", PoolBenchmarkManager.__new__(PoolBenchmarkManager))
    manager._precompile_worker_by_fn = {}
    manager._autotune_metrics = metrics
    manager._log = SilentLog()
    manager._pool = StubPool()

    result = manager.precompile(
        [Config()],
        [fake_fn],
        args_path="args.pt",
        timeout=1,
        desc=None,
        serialize_fn=serialize_fn,
    )

    # The elapsed time is still reported even though the compile failed.
    self.assertEqual(result.is_workings, [False])
    self.assertEqual(result.statuses, ["error"])
    self.assertEqual(result.compile_times, [0.25])
    self.assertEqual(metrics.num_compile_failures, 1)
    # A failed compile must not register a worker for the function.
    self.assertEqual(manager._precompile_worker_by_fn, {})
202+
144203

145204
# Subprocess benchmarking depends on Backend.supports_precompile(); only the
146205
# Triton backend supports it (Pallas/CuTe return False).

0 commit comments

Comments
 (0)