Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions helion/autotuner/benchmark_job.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,89 @@
"""Picklable benchmark job executed inside a ``BenchmarkWorker``."""
"""Picklable benchmark jobs executed inside a ``BenchmarkWorker``."""

from __future__ import annotations

import contextlib
import dataclasses
import functools
from typing import TYPE_CHECKING
from typing import Any
from typing import cast

import torch

from .benchmarking import do_bench
from .benchmarking import interleaved_bench
from .precompile_future import _load_compiled_fn

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Sequence

from .precompile_future import SerializedCompiledFunction


@functools.cache
@functools.lru_cache(maxsize=2)
def _load_args(path: str) -> Sequence[object]:
# Cached so re-spawning configs don't re-read the same args off disk.
return cast("Sequence[object]", torch.load(path))
"""Load Helion-created benchmark args in a worker process.

The cache is intentionally tiny: process-level pools see multiple shapes,
but each worker should only retain the latest args. ``weights_only=False``
is required because kernel args can include callables such as epilogues.
"""
return cast("Sequence[object]", torch.load(path, weights_only=False))


@dataclasses.dataclass
class PrecompileJob:
"""Compile-only precompile in a worker. Runs host-side helion code with
an extract_launcher that raises before any kernel launch, then triggers
Triton's compile (CPU + ptxas, no kernel execution). The binary lands
in Triton's on-disk cache for the benchmark phase to reuse.

Mirrors fork-mode children, but inside a long-lived spawn worker so the
parent never touches CUDA during prep."""

fn_spec: SerializedCompiledFunction
args_path: str

def __call__(self) -> bool:
from ..runtime.precompile_shim import already_compiled
from ..runtime.precompile_shim import already_compiled_fail
from ..runtime.precompile_shim import make_precompiler
from .precompile_future import _ExtractedLaunchArgs

fn = _load_compiled_fn(self.fn_spec)
args = _load_args(self.args_path)

captured: list[tuple[object, tuple[object, ...], dict[str, object]]] = []

def extract_launcher(
triton_kernel: object,
grid: tuple[int, ...],
*launch_args: object,
**launch_kwargs: object,
) -> object:
captured.append((triton_kernel, launch_args, launch_kwargs))
raise _ExtractedLaunchArgs(triton_kernel, grid, launch_args, launch_kwargs)

with contextlib.suppress(_ExtractedLaunchArgs):
fn(*args, _launcher=extract_launcher) # pyrefly: ignore[bad-argument-type]
if not captured:
# No kernel launch in host code -> nothing to compile.
return True

triton_fn, launch_args, launch_kwargs = captured[0]
precompiler = cast(
"Callable[..., bool]",
make_precompiler(cast("Any", triton_fn), None, None)(
*launch_args, **launch_kwargs
),
)
if precompiler is already_compiled:
return True
if precompiler is already_compiled_fail:
return False
return precompiler(in_child_process=False)


@dataclasses.dataclass
Expand All @@ -44,3 +106,20 @@ def __call__(self) -> float:
rep=self.rep,
),
)


@dataclasses.dataclass
class RebenchmarkJob:
"""Run interleaved rebenchmarking in the same isolated worker path."""

fn_specs: list[SerializedCompiledFunction]
args_path: str
repeat: int

def __call__(self) -> list[float]:
args = _load_args(self.args_path)
fns: list[Callable[[], object]] = [
functools.partial(_load_compiled_fn(fn_spec), *args)
for fn_spec in self.fn_specs
]
return interleaved_bench(fns, repeat=self.repeat)
104 changes: 104 additions & 0 deletions helion/autotuner/benchmark_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import ctypes.util
import multiprocessing as mp
import os
import queue
import signal
import sys
import threading
import time
from typing import TYPE_CHECKING
from typing import Callable
from typing import NamedTuple
from typing import TypeVar

from .logger import _UNRECOVERABLE_RUNTIME_ERROR_RE
Expand All @@ -21,6 +25,12 @@
_T = TypeVar("_T")


class WorkerPoolResult(NamedTuple):
worker_index: int
elapsed: float
result: object


def _set_pdeathsig() -> None:
"""SIGTERM the child if the parent dies (Linux only, best-effort)."""
if sys.platform != "linux":
Expand Down Expand Up @@ -146,3 +156,97 @@ def _kill(self) -> None:
connection.close()
self._process = None
self._parent_connection = None


class BenchmarkWorkerPool:
"""Pool of long-lived ``BenchmarkWorker`` processes."""

def __init__(self, num_workers: int) -> None:
if num_workers < 1:
raise ValueError(f"num_workers must be >= 1, got {num_workers}")
self.workers = [BenchmarkWorker(device=None) for _ in range(num_workers)]

@property
def num_workers(self) -> int:
return len(self.workers)

def run_job_on_worker(
self, worker_index: int, job: Callable[[], _T], timeout: float
) -> _T:
return self.workers[worker_index % self.num_workers].run(job, timeout=timeout)

def run_jobs(
self, jobs: list[Callable[[], object]], timeout: float
) -> list[WorkerPoolResult]:
"""Run jobs across the worker pool while preserving input order.

Each worker thread owns one worker and pulls job indices from a shared
queue, so slow jobs do not block unrelated workers. Worker exceptions
are captured in ``WorkerPoolResult.result`` for the corresponding job.
"""
if not jobs:
return []
active_workers = min(self.num_workers, len(jobs))
result_slots: list[WorkerPoolResult | None] = [None] * len(jobs)
work_queue: queue.Queue[int] = queue.Queue()
for i in range(len(jobs)):
work_queue.put(i)

def process_queue(worker_idx: int) -> None:
worker = self.workers[worker_idx]
while True:
try:
i = work_queue.get_nowait()
except queue.Empty:
return
start = time.perf_counter()
job_result = _run_job_capture_error(worker, jobs[i], timeout)
result_slots[i] = WorkerPoolResult(
worker_index=worker_idx,
elapsed=time.perf_counter() - start,
result=job_result,
)

_run_worker_threads(process_queue, active_workers)
ordered_results: list[WorkerPoolResult] = []
for slot in result_slots:
assert slot is not None
ordered_results.append(slot)
return ordered_results

def start_all(self, limit: int | None = None) -> None:
"""Start workers before threaded dispatch so their lifetime is not
tied to short-lived dispatch threads."""
if limit is None:
limit = self.num_workers
for worker in self.workers[:limit]:
if not worker.alive():
worker._start()

def shutdown(self) -> None:
for w in self.workers:
with contextlib.suppress(Exception):
w.shutdown()


def _run_job_capture_error(
worker: BenchmarkWorker, job: Callable[[], object], timeout: float
) -> object:
try:
return worker.run(job, timeout=timeout)
except BaseException as e:
e.__traceback__ = None
return e


def _run_worker_threads(target: Callable[[int], None], n: int) -> None:
if n == 1:
target(0)
return
threads = [
threading.Thread(target=target, args=(i,), daemon=True) for i in range(n)
]
for t in threads:
t.start()
for t in threads:
t.join()
57 changes: 44 additions & 13 deletions helion/autotuner/precompile_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import collections
import contextlib
import dataclasses
import functools
import hashlib
import inspect
import multiprocessing as mp
from multiprocessing import connection
Expand All @@ -20,7 +22,6 @@
from typing import Literal
from typing import NoReturn
from typing import cast
import uuid

import torch

Expand Down Expand Up @@ -144,20 +145,43 @@ def _serialize_compiled_fn(fn: CompiledConfig) -> SerializedCompiledFunction:


def _load_compiled_fn(fn_spec: SerializedCompiledFunction) -> CompiledConfig:
module_name = f"_helion_autotune_subprocess_{uuid.uuid4().hex}"
module = types.ModuleType(module_name)
module.__file__ = fn_spec.filename or "<helion-autotune-subprocess>"
module.__loader__ = None
module.__package__ = None
sys.modules[module_name] = module
exec(
compile(fn_spec.source_code, module.__file__, "exec"),
module.__dict__,
return _load_compiled_fn_cached(
fn_spec.function_name,
fn_spec.source_code,
fn_spec.filename,
)
fn = getattr(module, fn_spec.function_name, None)


# Generated modules retain code objects in long-lived workers, so keep this bounded.
@functools.lru_cache(maxsize=256)
def _load_compiled_fn_cached(
function_name: str,
source_code: str,
filename: str | None,
) -> CompiledConfig:
digest = hashlib.sha256()
digest.update(function_name.encode("utf-8"))
digest.update(b"\0")
digest.update((filename or "").encode("utf-8"))
digest.update(b"\0")
digest.update(source_code.encode("utf-8"))
module_name = f"_helion_autotune_subprocess_{digest.hexdigest()}"
if module_name in sys.modules:
module = sys.modules[module_name]
else:
module = types.ModuleType(module_name)
module.__file__ = filename or "<helion-autotune-subprocess>"
module.__loader__ = None
module.__package__ = None
sys.modules[module_name] = module
exec(
compile(source_code, module.__file__, "exec"),
module.__dict__,
)
fn = getattr(module, function_name, None)
if fn is None:
raise RuntimeError(
f"Unable to locate compiled kernel '{fn_spec.function_name}' in generated module"
f"Unable to locate compiled kernel '{function_name}' in generated module"
)
return fn

Expand All @@ -172,7 +196,10 @@ def _run_kernel_in_subprocess_spawn(
_cap: list[str] = [""]
try:
fn = _load_compiled_fn(fn_spec)
args = torch.load(args_path)
# The args file is created by the current autotune run in a private
# TemporaryDirectory. It may legitimately contain callable kernel args
# such as matmul epilogues, which PyTorch's weights_only loader rejects.
args = torch.load(args_path, weights_only=False)
assert isinstance(args, (tuple, list))
synchronize_device(None)
with capture_output() as _cap:
Expand Down Expand Up @@ -394,6 +421,10 @@ def create(
mode = ctx.settings.autotune_precompile
decorator = ctx.kernel.format_kernel_decorator(config, ctx.settings)

if mode == "pool":
raise exc.InvalidAPIUsage(
"autotune_precompile='pool' is handled by the benchmark worker pool"
)
if mode == "spawn":
mp_ctx = mp.get_context("spawn")
assert args_path is not None
Expand Down
22 changes: 17 additions & 5 deletions helion/runtime/precompile_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ def _get_helion_compilation_success(kernel: CompiledKernel) -> bool:

def make_precompiler(
fn: JITFunction[object],
config: Config,
bound_kernel: BoundKernel,
config: Config | None,
bound_kernel: BoundKernel | None,
) -> Callable[..., Callable[[], bool]]:
"""``bound_kernel``/``config`` may be ``None`` when invoked from a
subprocess that has only the Triton fn + launch args (e.g. the pool
worker); their only use here is error formatting."""
from .kernel import _find_device

def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], bool]:
Expand All @@ -50,9 +53,14 @@ def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], bool]:
kwargs["debug"] = (
kwargs.get("debug", fn.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
)
kernel_cache, *_, target, backend, binder = fn.device_caches[device]
kernel_cache, *cache_parts, target, backend, binder = fn.device_caches[device]
bound_args, specialization, options = binder(*args, **kwargs)
key = str(specialization) + str(options)
if cache_parts:
from triton.runtime.jit import compute_cache_key

key = compute_cache_key(cache_parts[0], specialization, options)
else:
key = str(specialization) + str(options)
kernel = kernel_cache.get(key, None)
if kernel is not None:
return (
Expand Down Expand Up @@ -98,7 +106,11 @@ def finish_it(in_child_process: bool = True) -> bool:
compiled_kernel._init_handles()
except Exception as e:
action = classify_triton_exception(e)
if action != "debug":
if (
action != "debug"
and bound_kernel is not None
and config is not None
):
print(
format_triton_compile_failure(config, e, bound_kernel),
file=sys.stderr,
Expand Down
Loading
Loading