Skip to content

Commit 733ba65

Browse files
choijon5jongsokchoi
authored andcommitted
[Autotuner] Long-lived worker pool for parallel precompile
1 parent d4261ff commit 733ba65

5 files changed

Lines changed: 306 additions & 11 deletions

File tree

helion/autotuner/benchmark_job.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
"""Picklable benchmark job executed inside a ``BenchmarkWorker``."""
1+
"""Picklable benchmark jobs executed inside a ``BenchmarkWorker``."""
22

33
from __future__ import annotations
44

5+
import contextlib
56
import dataclasses
67
import functools
78
from typing import TYPE_CHECKING
@@ -24,6 +25,70 @@ def _load_args(path: str) -> Sequence[object]:
2425
return cast("Sequence[object]", torch.load(path))
2526

2627

28+
@dataclasses.dataclass
29+
class WarmupJob:
30+
"""Pre-load args + init CUDA in a worker so the first PrecompileJob
31+
doesn't pay cold-start cost."""
32+
33+
args_path: str
34+
35+
def __call__(self) -> bool:
36+
_load_args(self.args_path)
37+
if torch.cuda.is_available():
38+
torch.cuda.init()
39+
return True
40+
41+
42+
@dataclasses.dataclass
43+
class PrecompileJob:
44+
"""Compile-only precompile in a worker. Runs host-side helion code with
45+
an extract_launcher that raises before any kernel launch, then triggers
46+
Triton's compile (CPU + ptxas, no kernel execution). The binary lands
47+
in Triton's on-disk cache for the benchmark phase to reuse.
48+
49+
Mirrors fork-mode children, but inside a long-lived spawn worker so the
50+
parent never touches CUDA during prep."""
51+
52+
fn_spec: SerializedCompiledFunction
53+
args_path: str
54+
55+
def __call__(self) -> bool:
56+
from ..runtime.precompile_shim import already_compiled
57+
from ..runtime.precompile_shim import already_compiled_fail
58+
from ..runtime.precompile_shim import make_precompiler
59+
from .precompile_future import _ExtractedLaunchArgs
60+
61+
fn = _load_compiled_fn(self.fn_spec)
62+
args = _load_args(self.args_path)
63+
64+
captured: list[tuple[object, tuple[object, ...], dict[str, object]]] = []
65+
66+
def extract_launcher(
67+
triton_kernel: object,
68+
grid: tuple[int, ...],
69+
*launch_args: object,
70+
**launch_kwargs: object,
71+
) -> object:
72+
captured.append((triton_kernel, launch_args, launch_kwargs))
73+
raise _ExtractedLaunchArgs(triton_kernel, grid, launch_args, launch_kwargs)
74+
75+
with contextlib.suppress(_ExtractedLaunchArgs):
76+
fn(*args, _launcher=extract_launcher) # pyrefly: ignore[bad-argument-type]
77+
if not captured:
78+
# No kernel launch in host code -> nothing to compile.
79+
return True
80+
81+
triton_fn, launch_args, launch_kwargs = captured[0]
82+
precompiler = make_precompiler( # pyrefly: ignore[bad-argument-type]
83+
triton_fn, None, None
84+
)(*launch_args, **launch_kwargs)
85+
if precompiler is already_compiled:
86+
return True
87+
if precompiler is already_compiled_fail:
88+
return False
89+
return precompiler(in_child_process=False)
90+
91+
2792
@dataclasses.dataclass
2893
class BenchmarkJob:
2994
fn_spec: SerializedCompiledFunction

helion/autotuner/benchmark_provider.py

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
from ..runtime.precompile_shim import already_compiled_fail
3030
from ..runtime.precompile_shim import make_precompiler
3131
from .benchmark_job import BenchmarkJob
32+
from .benchmark_job import PrecompileJob
3233
from .benchmark_worker import BenchmarkSubprocessError
3334
from .benchmark_worker import BenchmarkWorker
35+
from .benchmark_worker import BenchmarkWorkerPool
3436
from .benchmarking import do_bench
3537
from .benchmarking import synchronize_device
3638
from .logger import SUPPRESSED_TRITON_CODE_MSG
@@ -339,6 +341,7 @@ def __init__(
339341
self._precompile_args_path: str | None = None
340342
self._precompile_result_counter: count[int] = count()
341343
self._benchmark_worker: BenchmarkWorker | None = None
344+
self._worker_pool: BenchmarkWorkerPool | None = None
342345

343346
# TODO(hinriksnaer): baseline computation is expensive (compiles and runs
344347
# the kernel). Currently safe because the provider is only constructed
@@ -541,7 +544,12 @@ def _precompile_context(self) -> PrecompileContext:
541544
)
542545

543546
def setup(self) -> None:
544-
"""Prepare precompile tmpdir and args for spawn mode."""
547+
"""Prepare precompile tmpdir and args. Eagerly warms the worker pool
548+
when worker-pool precompile is enabled so spawn + ``torch.load`` cost
549+
runs concurrently with the parent's other setup work, not in the
550+
critical path of the first ``map_jobs``."""
551+
from .benchmark_job import WarmupJob
552+
545553
if self._precompile_tmpdir is None:
546554
self._precompile_tmpdir = tempfile.TemporaryDirectory()
547555
if (
@@ -552,6 +560,14 @@ def setup(self) -> None:
552560
torch.save(self.args, args_path)
553561
self._precompile_args_path = args_path
554562

563+
if self._worker_precompile_enabled():
564+
assert self._precompile_args_path is not None
565+
args_path = self._precompile_args_path
566+
self._ensure_worker_pool().warmup(
567+
lambda: WarmupJob(args_path=args_path),
568+
timeout=float(self.settings.autotune_compile_timeout),
569+
)
570+
555571
def _next_precompile_result_path(self) -> str:
556572
"""Return a fresh path for a precompile result file."""
557573
if self._precompile_tmpdir is None:
@@ -566,6 +582,9 @@ def cleanup(self) -> None:
566582
if self._benchmark_worker is not None:
567583
self._benchmark_worker.shutdown()
568584
self._benchmark_worker = None
585+
if self._worker_pool is not None:
586+
self._worker_pool.shutdown()
587+
self._worker_pool = None
569588
if self._precompile_tmpdir is not None:
570589
self._precompile_tmpdir.cleanup()
571590
self._precompile_tmpdir = None
@@ -585,6 +604,43 @@ def _subprocess_benchmark_enabled(self) -> bool:
585604
_backend = getattr(self.config_spec, "backend", None)
586605
return not (_backend is not None and _backend.get_do_bench() is not None)
587606

607+
def _worker_precompile_enabled(self) -> bool:
608+
"""Worker-pool precompile is the default safe path when subprocess
609+
benchmark is enabled and the kernel has args saved to disk. Pool size
610+
auto-decides from GPU memory + cpu count; users can override via
611+
``HELION_AUTOTUNE_PRECOMPILE_WORKERS=<n>`` (or set ``< 0`` to disable)."""
612+
return (
613+
self.settings.autotune_precompile_workers >= 0
614+
and self._subprocess_benchmark_enabled()
615+
and self._precompile_args_path is not None
616+
and self._pool_size() >= 1
617+
)
618+
619+
def _pool_size(self) -> int:
620+
"""Resolve the effective pool size. ``autotune_precompile_workers > 0``
621+
is honored verbatim. Otherwise pick ``min(cpu_count, free_mem // est)``
622+
where ``est`` accounts for compile-only per-worker memory: args + a brief
623+
output-allocation peak + CUDA driver overhead, with a 2x safety factor."""
624+
explicit = self.settings.autotune_precompile_workers
625+
if explicit > 0:
626+
return explicit
627+
cpu_cap = os.cpu_count() or 1
628+
device = self.kernel.env.device
629+
if device.type != "cuda":
630+
return cpu_cap
631+
args_bytes = _estimate_tree_bytes(self.args)
632+
per_worker_bytes = (args_bytes + max(args_bytes, 1 * 1024**3)) * 2
633+
if per_worker_bytes <= 0:
634+
return cpu_cap
635+
available_memory, _ = torch.cuda.mem_get_info(device)
636+
memory_cap = max(1, int(available_memory * 0.9) // per_worker_bytes)
637+
return min(cpu_cap, memory_cap)
638+
639+
def _ensure_worker_pool(self) -> BenchmarkWorkerPool:
640+
if self._worker_pool is None:
641+
self._worker_pool = BenchmarkWorkerPool(num_workers=self._pool_size())
642+
return self._worker_pool
643+
588644
def _validate_against_baseline(
589645
self, config: Config, output: object, args: Sequence[object]
590646
) -> bool:
@@ -676,7 +732,17 @@ def benchmark(
676732
configs = [all_configs[i] for i in valid_indices]
677733

678734
# Precompile phase
679-
if self.settings.autotune_precompile:
735+
precompile_status: list[Literal["ok", "error", "timeout"]] = []
736+
compile_times: list[float | None] = [None] * len(configs)
737+
if self._worker_precompile_enabled() and self.settings.autotune_precompile:
738+
precompile_desc = (
739+
f"{desc} precompiling" if self.settings.autotune_progress_bar else None
740+
)
741+
is_workings, precompile_status, compile_times = (
742+
self._worker_pool_precompile(configs, fns, precompile_desc)
743+
)
744+
futures = None
745+
elif self.settings.autotune_precompile:
680746
futures = list(
681747
starmap(
682748
self._create_precompile_future,
@@ -687,7 +753,6 @@ def benchmark(
687753
f"{desc} precompiling" if self.settings.autotune_progress_bar else None
688754
)
689755
is_workings = PrecompileFuture.wait_for_all(futures, desc=precompile_desc)
690-
precompile_status: list[Literal["ok", "error", "timeout"]] = []
691756
for future, ok in zip(futures, is_workings, strict=True):
692757
reason = future.failure_reason
693758
if ok:
@@ -697,6 +762,7 @@ def benchmark(
697762
else:
698763
precompile_status.append("error")
699764
else:
765+
futures = None
700766
is_workings = [True] * len(configs)
701767
precompile_status = ["ok"] * len(configs)
702768

@@ -725,7 +791,7 @@ def benchmark(
725791
else None
726792
)
727793
else:
728-
compile_time = None
794+
compile_time = compile_times[index]
729795
status: Literal[
730796
"ok", "error", "timeout", "peer_compilation_fail", "filtered"
731797
]
@@ -954,6 +1020,65 @@ def _benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
9541020
self._autotune_metrics.num_compile_failures += 1
9551021
return inf
9561022

1023+
def _worker_pool_precompile(
1024+
self,
1025+
configs: list[Config],
1026+
fns: list[CompiledConfig],
1027+
desc: str | None,
1028+
) -> tuple[
1029+
list[bool],
1030+
list[Literal["ok", "error", "timeout"]],
1031+
list[float | None],
1032+
]:
1033+
"""Compile each config in the long-lived worker pool. Returns
1034+
``(is_workings, statuses, compile_times)`` aligned with ``configs``."""
1035+
assert self._precompile_args_path is not None
1036+
args_path = self._precompile_args_path
1037+
timeout = float(self.settings.autotune_compile_timeout)
1038+
1039+
# Build PrecompileJobs; serialization failures count as compile failures.
1040+
jobs: list[PrecompileJob | None] = []
1041+
for fn in fns:
1042+
try:
1043+
jobs.append(
1044+
PrecompileJob(
1045+
fn_spec=_serialize_compiled_fn(fn), args_path=args_path
1046+
)
1047+
)
1048+
except RuntimeError:
1049+
jobs.append(None)
1050+
1051+
live_idxs = [i for i, j in enumerate(jobs) if j is not None]
1052+
live_jobs = cast("list[Callable[[], object]]", [jobs[i] for i in live_idxs])
1053+
t0 = time.perf_counter()
1054+
live_results = self._ensure_worker_pool().map_jobs(live_jobs, timeout=timeout)
1055+
elapsed = time.perf_counter() - t0
1056+
1057+
is_workings = [False] * len(configs)
1058+
statuses: list[Literal["ok", "error", "timeout"]] = ["error"] * len(configs)
1059+
compile_times: list[float | None] = [None] * len(configs)
1060+
for idx, result in zip(live_idxs, live_results, strict=True):
1061+
compile_times[idx] = elapsed
1062+
if isinstance(result, BaseException):
1063+
statuses[idx] = (
1064+
"timeout"
1065+
if isinstance(result, BenchmarkSubprocessError)
1066+
and "timeout" in str(result).lower()
1067+
else "error"
1068+
)
1069+
self.log.debug(
1070+
f"Precompile worker failed for {configs[idx]!r}: "
1071+
f"{type(result).__name__}: {result}"
1072+
)
1073+
self._autotune_metrics.num_compile_failures += 1
1074+
else:
1075+
is_workings[idx] = True
1076+
statuses[idx] = "ok"
1077+
1078+
if desc:
1079+
self.log(f"{desc} 100% via worker pool ({len(live_idxs)} configs)")
1080+
return is_workings, statuses, compile_times
1081+
9571082
def _benchmark_function_subprocess(
9581083
self, config: Config, fn: CompiledConfig
9591084
) -> float | None:
@@ -969,8 +1094,14 @@ def _benchmark_function_subprocess(
9691094
except RuntimeError:
9701095
return None
9711096

972-
if self._benchmark_worker is None:
973-
self._benchmark_worker = BenchmarkWorker(device=None)
1097+
# Prefer the pool's first worker if a pool is active so the same CUDA
1098+
# context that compiled also benchmarks (Triton cache hit, no recompile).
1099+
if self._worker_pool is not None:
1100+
run_in_worker = lambda j, t: self._worker_pool.run_one(j, timeout=t) # noqa: E731
1101+
else:
1102+
if self._benchmark_worker is None:
1103+
self._benchmark_worker = BenchmarkWorker(device=None)
1104+
run_in_worker = lambda j, t: self._benchmark_worker.run(j, timeout=t) # noqa: E731
9741105

9751106
job = BenchmarkJob(
9761107
fn_spec=fn_spec,
@@ -981,7 +1112,7 @@ def _benchmark_function_subprocess(
9811112
timeout = float(self.settings.autotune_benchmark_timeout)
9821113

9831114
try:
984-
latency = self._benchmark_worker.run(job, timeout=timeout)
1115+
latency = run_in_worker(job, timeout)
9851116
except BenchmarkSubprocessError as e:
9861117
# Timeout or unexpected worker exit; skip config and continue.
9871118
self.log.warning(f"Benchmark subprocess failed for {config!r}: {e}")

0 commit comments

Comments
 (0)