Skip to content

Commit 023dcfe

Browse files
committed
Harden self-play runtime with progress pulses and stall timeouts
1 parent 95cfe80 commit 023dcfe

2 files changed

Lines changed: 105 additions & 11 deletions

File tree

src/training/config_runtime.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
"eval_sims": 400,
9292
"eval_heuristic_level": "hard",
9393
"selfplay_workers": 8,
94+
"selfplay_progress_every_s": 120.0,
95+
"selfplay_episode_timeout_s": 1800.0,
9496
"compile_model": True,
9597
"quiet_mode": False,
9698
"warmup_games": 600,
@@ -158,6 +160,8 @@ def parse_args() -> argparse.Namespace:
158160
parser.add_argument("--eval-games", type=int, default=None)
159161
parser.add_argument("--eval-sims", type=int, default=None)
160162
parser.add_argument("--selfplay-workers", type=int, default=None)
163+
parser.add_argument("--selfplay-progress-every-s", type=float, default=None)
164+
parser.add_argument("--selfplay-episode-timeout-s", type=float, default=None)
161165
parser.add_argument("--allow-selfplay-fallback", action="store_true")
162166
parser.add_argument("--allow-hf-upload-errors", action="store_true")
163167
parser.add_argument("--warmup-games", type=int, default=None)
@@ -264,6 +268,10 @@ def apply_cli_overrides(args: argparse.Namespace) -> None:
264268
CONFIG["eval_sims"] = max(8, args.eval_sims)
265269
if args.selfplay_workers is not None:
266270
CONFIG["selfplay_workers"] = max(1, args.selfplay_workers)
271+
if args.selfplay_progress_every_s is not None:
272+
CONFIG["selfplay_progress_every_s"] = max(5.0, args.selfplay_progress_every_s)
273+
if args.selfplay_episode_timeout_s is not None:
274+
CONFIG["selfplay_episode_timeout_s"] = max(0.0, args.selfplay_episode_timeout_s)
267275
if args.allow_selfplay_fallback:
268276
CONFIG["fail_on_selfplay_parallel_error"] = False
269277
if args.allow_hf_upload_errors:
@@ -356,6 +364,10 @@ def validate_config() -> None:
356364
raise ValueError("CONFIG['max_pending_hf_uploads'] must be > 0.")
357365
if cfg_float("hf_upload_future_timeout_s") <= 0.0:
358366
raise ValueError("CONFIG['hf_upload_future_timeout_s'] must be > 0.")
367+
if cfg_float("selfplay_progress_every_s") <= 0.0:
368+
raise ValueError("CONFIG['selfplay_progress_every_s'] must be > 0.")
369+
if cfg_float("selfplay_episode_timeout_s") < 0.0:
370+
raise ValueError("CONFIG['selfplay_episode_timeout_s'] must be >= 0.")
359371

360372
opp_sum = (
361373
cfg_float("opponent_self_prob")

src/training/selfplay_runtime.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import heapq
4-
from concurrent.futures import ProcessPoolExecutor
4+
import multiprocessing as mp
5+
import time
6+
from concurrent.futures import FIRST_COMPLETED, Future, ProcessPoolExecutor, wait
57
from typing import TYPE_CHECKING, cast
68

79
import numpy as np
@@ -282,7 +284,14 @@ def execute_self_play(
282284
add_noise = cfg_bool("add_noise")
283285
rng = np.random.default_rng(seed=cfg_int("seed") + iteration)
284286
selfplay_workers = cfg_int("selfplay_workers")
285-
log(f"[Iteration {iteration}] Self-play episodes: {episodes}", verbose_only=True)
287+
progress_every_s = max(5.0, cfg_float("selfplay_progress_every_s"))
288+
episode_timeout_s = max(0.0, cfg_float("selfplay_episode_timeout_s"))
289+
selfplay_start = time.perf_counter()
290+
last_progress_log_s = selfplay_start
291+
log(
292+
f"[Iteration {iteration}] self-play start episodes={episodes} sims={cfg_int('mcts_sims')} "
293+
f"workers={selfplay_workers}",
294+
)
286295
curriculum_mix = get_curriculum_mix(iteration)
287296
log(
288297
" Opponent mix: "
@@ -353,6 +362,9 @@ def execute_self_play(
353362
}
354363
with ProcessPoolExecutor(
355364
max_workers=max_workers,
365+
# Forking after CUDA/DDP warmup can deadlock on Linux notebooks.
366+
# We force "spawn" to keep self-play workers reliable on Kaggle.
367+
mp_context=mp.get_context("spawn"),
356368
initializer=_init_selfplay_process_worker,
357369
initargs=(
358370
model_state_dict,
@@ -361,20 +373,83 @@ def execute_self_play(
361373
cfg_int("mcts_sims"),
362374
),
363375
) as executor:
364-
for (
365-
(_, opponent_type, heuristic_level, _),
366-
episode_result,
367-
) in zip(
368-
episode_specs,
369-
executor.map(_run_episode_in_process_worker, worker_payloads),
370-
strict=True,
376+
futures: dict[Future[tuple[list[tuple[np.ndarray, np.ndarray, int]], int, int]], tuple[int, str, str]] = {}
377+
submitted_at: dict[
378+
Future[tuple[list[tuple[np.ndarray, np.ndarray, int]], int, int]],
379+
float,
380+
] = {}
381+
ordered_results: list[
382+
tuple[int, str, str, list[tuple[np.ndarray, np.ndarray, int]], int, int]
383+
] = []
384+
for idx, ((_, opponent_type, heuristic_level, _), payload) in enumerate(
385+
zip(episode_specs, worker_payloads, strict=True),
386+
start=1,
371387
):
372-
game_history, winner, turn_idx = episode_result
388+
future = executor.submit(_run_episode_in_process_worker, payload)
389+
futures[future] = (idx, opponent_type, heuristic_level)
390+
submitted_at[future] = time.perf_counter()
391+
392+
pending = set(futures)
393+
while pending:
394+
done, pending = wait(
395+
pending,
396+
timeout=progress_every_s,
397+
return_when=FIRST_COMPLETED,
398+
)
399+
now_s = time.perf_counter()
400+
if not done:
401+
if (now_s - last_progress_log_s) >= progress_every_s:
402+
log(
403+
f"[Iteration {iteration}] self-play progress "
404+
f"{len(ordered_results)}/{episodes} elapsed={now_s - selfplay_start:.0f}s",
405+
)
406+
last_progress_log_s = now_s
407+
if episode_timeout_s <= 0.0:
408+
continue
409+
oldest_pending_s = max(now_s - submitted_at[fut] for fut in pending)
410+
if oldest_pending_s <= episode_timeout_s:
411+
continue
412+
raise TimeoutError(
413+
"Parallel self-play stalled: "
414+
f"oldest pending episode exceeded {episode_timeout_s:.0f}s.",
415+
)
416+
417+
for future in done:
418+
idx, opponent_type, heuristic_level = futures[future]
419+
game_history, winner, turn_idx = future.result()
420+
ordered_results.append(
421+
(
422+
idx,
423+
opponent_type,
424+
heuristic_level,
425+
game_history,
426+
winner,
427+
turn_idx,
428+
)
429+
)
430+
submitted_at.pop(future, None)
431+
432+
if (now_s - last_progress_log_s) >= progress_every_s:
433+
log(
434+
f"[Iteration {iteration}] self-play progress "
435+
f"{len(ordered_results)}/{episodes} elapsed={now_s - selfplay_start:.0f}s",
436+
)
437+
last_progress_log_s = now_s
438+
439+
ordered_results.sort(key=lambda item: item[0])
440+
for (
441+
_idx,
442+
opponent_type,
443+
heuristic_level,
444+
game_history,
445+
winner,
446+
turn_idx,
447+
) in ordered_results:
373448
episode_results.append(
374449
(opponent_type, heuristic_level, game_history, winner, turn_idx)
375450
)
376451
used_parallel = True
377-
log(f" Self-play process workers active: {max_workers}", verbose_only=True)
452+
log(f"[Iteration {iteration}] self-play process workers active: {max_workers}")
378453
except Exception as exc:
379454
handle_parallel_selfplay_failure(exc)
380455
episode_results.clear()
@@ -392,6 +467,13 @@ def execute_self_play(
392467
model_player=model_player,
393468
)
394469
episode_results.append((opponent_type, heuristic_level, game_history, winner, turn_idx))
470+
now_s = time.perf_counter()
471+
if (now_s - last_progress_log_s) >= progress_every_s:
472+
log(
473+
f"[Iteration {iteration}] self-play progress "
474+
f"{len(episode_results)}/{episodes} elapsed={now_s - selfplay_start:.0f}s",
475+
)
476+
last_progress_log_s = now_s
395477

396478
for episode_idx, (opponent_type, heuristic_level, game_history, winner, turn_idx) in enumerate(
397479
episode_results,

0 commit comments

Comments
 (0)