Skip to content

Commit c94c0a7

Browse files
committed
fix(bench): wire _init_worker into pools, force RAYON_NUM_THREADS=1
1 parent 72d0ebc commit c94c0a7

4 files changed

Lines changed: 43 additions & 8 deletions

File tree

benchmarks/adapters/calibrate.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,16 @@ def evaluate_grid(
9292

9393
from concurrent.futures.process import BrokenProcessPool
9494

95+
from benchmarks.common import _init_worker
96+
9597
def _make_pool() -> ProcessPoolExecutor:
9698
ctx = mp.get_context("spawn")
97-
p = ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=50)
99+
p = ProcessPoolExecutor(
100+
max_workers=workers,
101+
mp_context=ctx,
102+
max_tasks_per_child=50,
103+
initializer=_init_worker,
104+
)
98105
list(p.map(int, range(workers))) # eager-spawn all workers
99106
return p
100107

@@ -193,9 +200,16 @@ def evaluate_grid_cached( # noqa: C901 — pool teardown + per-cell demux + ret
193200
if needed:
194201
pending.append((inst, needed))
195202

203+
from benchmarks.common import _init_worker
204+
196205
def _make_pool() -> ProcessPoolExecutor:
197206
ctx = mp.get_context("spawn")
198-
p = ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=40)
207+
p = ProcessPoolExecutor(
208+
max_workers=workers,
209+
mp_context=ctx,
210+
max_tasks_per_child=40,
211+
initializer=_init_worker,
212+
)
199213
list(p.map(int, range(workers)))
200214
return p
201215

benchmarks/adapters/runner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,14 @@ def _drain(active_pool: ProcessPoolExecutor) -> None: # noqa: C901
297297
_drain(pool) # type: ignore[arg-type]
298298
return
299299

300+
from benchmarks.common import _init_worker
301+
300302
ctx = mp.get_context("spawn")
301-
with ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=50) as owned:
303+
with ProcessPoolExecutor(
304+
max_workers=workers,
305+
mp_context=ctx,
306+
max_tasks_per_child=50,
307+
initializer=_init_worker,
308+
) as owned:
302309
_drain(owned)
303310
owned.shutdown(wait=False, cancel_futures=True)

benchmarks/common.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,17 @@ def _init_worker() -> None:
302302

303303
warnings.filterwarnings("ignore", category=SyntaxWarning)
304304

305-
os.environ.setdefault("RAYON_NUM_THREADS", os.environ.get("BENCH_RAYON_THREADS", "1"))
306-
os.environ.setdefault("OMP_NUM_THREADS", "1")
307-
os.environ.setdefault("MKL_NUM_THREADS", "1")
308-
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
309-
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
305+
# Explicit assignment, not setdefault: parent may have inherited a
306+
# different value (e.g. ambient RAYON_NUM_THREADS=24 from a shell
307+
# profile) which would defeat the cap. With N workers each running
308+
# Rayon at full core count, we get N x cores threads contending for
309+
# the same cores — the calibration timeout fires before the
310+
# algorithm finishes.
311+
os.environ["RAYON_NUM_THREADS"] = os.environ.get("BENCH_RAYON_THREADS", "1")
312+
os.environ["OMP_NUM_THREADS"] = "1"
313+
os.environ["MKL_NUM_THREADS"] = "1"
314+
os.environ["OPENBLAS_NUM_THREADS"] = "1"
315+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
310316

311317
for mod in (
312318
"_diffctx",

benchmarks/diffctx_eval_fn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ def _arm_diffctx_kill_switch(timeout_s: float):
8585
import threading
8686

8787
if timeout_s <= 0:
88+
# Trace once per worker so we know whether the env propagated.
89+
if not getattr(_arm_diffctx_kill_switch, "_warned_no_timeout", False):
90+
print(
91+
f"[worker pid={os.getpid()}] WARN: DIFFCTX_BENCH_TIMEOUT_SEC unset "
92+
f"or <=0, timer disabled (raw value: {os.environ.get('DIFFCTX_BENCH_TIMEOUT_SEC', 'MISSING')!r})",
93+
flush=True,
94+
)
95+
_arm_diffctx_kill_switch._warned_no_timeout = True # type: ignore[attr-defined]
8896
return None
8997
timer = threading.Timer(timeout_s, lambda: os._exit(137))
9098
timer.daemon = True

0 commit comments

Comments
 (0)