2929from ..runtime .precompile_shim import already_compiled_fail
3030from ..runtime .precompile_shim import make_precompiler
3131from .benchmark_job import BenchmarkJob
32+ from .benchmark_job import PrecompileJob
3233from .benchmark_worker import BenchmarkSubprocessError
3334from .benchmark_worker import BenchmarkWorker
35+ from .benchmark_worker import BenchmarkWorkerPool
3436from .benchmarking import do_bench
3537from .benchmarking import synchronize_device
3638from .logger import SUPPRESSED_TRITON_CODE_MSG
@@ -339,6 +341,7 @@ def __init__(
339341 self ._precompile_args_path : str | None = None
340342 self ._precompile_result_counter : count [int ] = count ()
341343 self ._benchmark_worker : BenchmarkWorker | None = None
344+ self ._worker_pool : BenchmarkWorkerPool | None = None
342345
343346 # TODO(hinriksnaer): baseline computation is expensive (compiles and runs
344347 # the kernel). Currently safe because the provider is only constructed
@@ -541,7 +544,12 @@ def _precompile_context(self) -> PrecompileContext:
541544 )
542545
543546 def setup (self ) -> None :
544- """Prepare precompile tmpdir and args for spawn mode."""
547+ """Prepare precompile tmpdir and args. Eagerly warms the worker pool
548+ when worker-pool precompile is enabled so spawn + ``torch.load`` cost
549+ runs concurrently with the parent's other setup work, not in the
550+ critical path of the first ``map_jobs``."""
551+ from .benchmark_job import WarmupJob
552+
545553 if self ._precompile_tmpdir is None :
546554 self ._precompile_tmpdir = tempfile .TemporaryDirectory ()
547555 if (
@@ -552,6 +560,14 @@ def setup(self) -> None:
552560 torch .save (self .args , args_path )
553561 self ._precompile_args_path = args_path
554562
563+ if self ._worker_precompile_enabled ():
564+ assert self ._precompile_args_path is not None
565+ args_path = self ._precompile_args_path
566+ self ._ensure_worker_pool ().warmup (
567+ lambda : WarmupJob (args_path = args_path ),
568+ timeout = float (self .settings .autotune_compile_timeout ),
569+ )
570+
555571 def _next_precompile_result_path (self ) -> str :
556572 """Return a fresh path for a precompile result file."""
557573 if self ._precompile_tmpdir is None :
@@ -566,6 +582,9 @@ def cleanup(self) -> None:
566582 if self ._benchmark_worker is not None :
567583 self ._benchmark_worker .shutdown ()
568584 self ._benchmark_worker = None
585+ if self ._worker_pool is not None :
586+ self ._worker_pool .shutdown ()
587+ self ._worker_pool = None
569588 if self ._precompile_tmpdir is not None :
570589 self ._precompile_tmpdir .cleanup ()
571590 self ._precompile_tmpdir = None
@@ -585,6 +604,43 @@ def _subprocess_benchmark_enabled(self) -> bool:
585604 _backend = getattr (self .config_spec , "backend" , None )
586605 return not (_backend is not None and _backend .get_do_bench () is not None )
587606
607+ def _worker_precompile_enabled (self ) -> bool :
608+ """Worker-pool precompile is the default safe path when subprocess
609+ benchmark is enabled and the kernel has args saved to disk. Pool size
610+ auto-decides from GPU memory + cpu count; users can override via
611+ ``HELION_AUTOTUNE_PRECOMPILE_WORKERS=<n>`` (or set ``< 0`` to disable)."""
612+ return (
613+ self .settings .autotune_precompile_workers >= 0
614+ and self ._subprocess_benchmark_enabled ()
615+ and self ._precompile_args_path is not None
616+ and self ._pool_size () >= 1
617+ )
618+
619+ def _pool_size (self ) -> int :
620+ """Resolve the effective pool size. ``autotune_precompile_workers > 0``
621+ is honored verbatim. Otherwise pick ``min(cpu_count, free_mem // est)``
622+ where ``est`` accounts for compile-only per-worker memory: args + a brief
623+ output-allocation peak + CUDA driver overhead, with a 2x safety factor."""
624+ explicit = self .settings .autotune_precompile_workers
625+ if explicit > 0 :
626+ return explicit
627+ cpu_cap = os .cpu_count () or 1
628+ device = self .kernel .env .device
629+ if device .type != "cuda" :
630+ return cpu_cap
631+ args_bytes = _estimate_tree_bytes (self .args )
632+ per_worker_bytes = (args_bytes + max (args_bytes , 1 * 1024 ** 3 )) * 2
633+ if per_worker_bytes <= 0 :
634+ return cpu_cap
635+ available_memory , _ = torch .cuda .mem_get_info (device )
636+ memory_cap = max (1 , int (available_memory * 0.9 ) // per_worker_bytes )
637+ return min (cpu_cap , memory_cap )
638+
639+ def _ensure_worker_pool (self ) -> BenchmarkWorkerPool :
640+ if self ._worker_pool is None :
641+ self ._worker_pool = BenchmarkWorkerPool (num_workers = self ._pool_size ())
642+ return self ._worker_pool
643+
588644 def _validate_against_baseline (
589645 self , config : Config , output : object , args : Sequence [object ]
590646 ) -> bool :
@@ -676,7 +732,17 @@ def benchmark(
676732 configs = [all_configs [i ] for i in valid_indices ]
677733
678734 # Precompile phase
679- if self .settings .autotune_precompile :
735+ precompile_status : list [Literal ["ok" , "error" , "timeout" ]] = []
736+ compile_times : list [float | None ] = [None ] * len (configs )
737+ if self ._worker_precompile_enabled () and self .settings .autotune_precompile :
738+ precompile_desc = (
739+ f"{ desc } precompiling" if self .settings .autotune_progress_bar else None
740+ )
741+ is_workings , precompile_status , compile_times = (
742+ self ._worker_pool_precompile (configs , fns , precompile_desc )
743+ )
744+ futures = None
745+ elif self .settings .autotune_precompile :
680746 futures = list (
681747 starmap (
682748 self ._create_precompile_future ,
@@ -687,7 +753,6 @@ def benchmark(
687753 f"{ desc } precompiling" if self .settings .autotune_progress_bar else None
688754 )
689755 is_workings = PrecompileFuture .wait_for_all (futures , desc = precompile_desc )
690- precompile_status : list [Literal ["ok" , "error" , "timeout" ]] = []
691756 for future , ok in zip (futures , is_workings , strict = True ):
692757 reason = future .failure_reason
693758 if ok :
@@ -697,6 +762,7 @@ def benchmark(
697762 else :
698763 precompile_status .append ("error" )
699764 else :
765+ futures = None
700766 is_workings = [True ] * len (configs )
701767 precompile_status = ["ok" ] * len (configs )
702768
@@ -725,7 +791,7 @@ def benchmark(
725791 else None
726792 )
727793 else :
728- compile_time = None
794+ compile_time = compile_times [ index ]
729795 status : Literal [
730796 "ok" , "error" , "timeout" , "peer_compilation_fail" , "filtered"
731797 ]
@@ -954,6 +1020,65 @@ def _benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
9541020 self ._autotune_metrics .num_compile_failures += 1
9551021 return inf
9561022
1023+ def _worker_pool_precompile (
1024+ self ,
1025+ configs : list [Config ],
1026+ fns : list [CompiledConfig ],
1027+ desc : str | None ,
1028+ ) -> tuple [
1029+ list [bool ],
1030+ list [Literal ["ok" , "error" , "timeout" ]],
1031+ list [float | None ],
1032+ ]:
1033+ """Compile each config in the long-lived worker pool. Returns
1034+ ``(is_workings, statuses, compile_times)`` aligned with ``configs``."""
1035+ assert self ._precompile_args_path is not None
1036+ args_path = self ._precompile_args_path
1037+ timeout = float (self .settings .autotune_compile_timeout )
1038+
1039+ # Build PrecompileJobs; serialization failures count as compile failures.
1040+ jobs : list [PrecompileJob | None ] = []
1041+ for fn in fns :
1042+ try :
1043+ jobs .append (
1044+ PrecompileJob (
1045+ fn_spec = _serialize_compiled_fn (fn ), args_path = args_path
1046+ )
1047+ )
1048+ except RuntimeError :
1049+ jobs .append (None )
1050+
1051+ live_idxs = [i for i , j in enumerate (jobs ) if j is not None ]
1052+ live_jobs = cast ("list[Callable[[], object]]" , [jobs [i ] for i in live_idxs ])
1053+ t0 = time .perf_counter ()
1054+ live_results = self ._ensure_worker_pool ().map_jobs (live_jobs , timeout = timeout )
1055+ elapsed = time .perf_counter () - t0
1056+
1057+ is_workings = [False ] * len (configs )
1058+ statuses : list [Literal ["ok" , "error" , "timeout" ]] = ["error" ] * len (configs )
1059+ compile_times : list [float | None ] = [None ] * len (configs )
1060+ for idx , result in zip (live_idxs , live_results , strict = True ):
1061+ compile_times [idx ] = elapsed
1062+ if isinstance (result , BaseException ):
1063+ statuses [idx ] = (
1064+ "timeout"
1065+ if isinstance (result , BenchmarkSubprocessError )
1066+ and "timeout" in str (result ).lower ()
1067+ else "error"
1068+ )
1069+ self .log .debug (
1070+ f"Precompile worker failed for { configs [idx ]!r} : "
1071+ f"{ type (result ).__name__ } : { result } "
1072+ )
1073+ self ._autotune_metrics .num_compile_failures += 1
1074+ else :
1075+ is_workings [idx ] = True
1076+ statuses [idx ] = "ok"
1077+
1078+ if desc :
1079+ self .log (f"{ desc } 100% via worker pool ({ len (live_idxs )} configs)" )
1080+ return is_workings , statuses , compile_times
1081+
9571082 def _benchmark_function_subprocess (
9581083 self , config : Config , fn : CompiledConfig
9591084 ) -> float | None :
@@ -969,8 +1094,14 @@ def _benchmark_function_subprocess(
9691094 except RuntimeError :
9701095 return None
9711096
972- if self ._benchmark_worker is None :
973- self ._benchmark_worker = BenchmarkWorker (device = None )
1097+ # Prefer the pool's first worker if a pool is active so the same CUDA
1098+ # context that compiled also benchmarks (Triton cache hit, no recompile).
1099+ if self ._worker_pool is not None :
1100+ run_in_worker = lambda j , t : self ._worker_pool .run_one (j , timeout = t ) # noqa: E731
1101+ else :
1102+ if self ._benchmark_worker is None :
1103+ self ._benchmark_worker = BenchmarkWorker (device = None )
1104+ run_in_worker = lambda j , t : self ._benchmark_worker .run (j , timeout = t ) # noqa: E731
9741105
9751106 job = BenchmarkJob (
9761107 fn_spec = fn_spec ,
@@ -981,7 +1112,7 @@ def _benchmark_function_subprocess(
9811112 timeout = float (self .settings .autotune_benchmark_timeout )
9821113
9831114 try :
984- latency = self . _benchmark_worker . run (job , timeout = timeout )
1115+ latency = run_in_worker (job , timeout )
9851116 except BenchmarkSubprocessError as e :
9861117 # Timeout or unexpected worker exit; skip config and continue.
9871118 self .log .warning (f"Benchmark subprocess failed for { config !r} : { e } " )
0 commit comments