11from __future__ import annotations
22
33import heapq
4- from concurrent .futures import ProcessPoolExecutor
4+ import multiprocessing as mp
5+ import time
6+ from concurrent .futures import FIRST_COMPLETED , Future , ProcessPoolExecutor , wait
57from typing import TYPE_CHECKING , cast
68
79import numpy as np
@@ -282,7 +284,14 @@ def execute_self_play(
282284 add_noise = cfg_bool ("add_noise" )
283285 rng = np .random .default_rng (seed = cfg_int ("seed" ) + iteration )
284286 selfplay_workers = cfg_int ("selfplay_workers" )
285- log (f"[Iteration { iteration } ] Self-play episodes: { episodes } " , verbose_only = True )
287+ progress_every_s = max (5.0 , cfg_float ("selfplay_progress_every_s" ))
288+ episode_timeout_s = max (0.0 , cfg_float ("selfplay_episode_timeout_s" ))
289+ selfplay_start = time .perf_counter ()
290+ last_progress_log_s = selfplay_start
291+ log (
292+ f"[Iteration { iteration } ] self-play start episodes={ episodes } sims={ cfg_int ('mcts_sims' )} "
293+ f"workers={ selfplay_workers } " ,
294+ )
286295 curriculum_mix = get_curriculum_mix (iteration )
287296 log (
288297 " Opponent mix: "
@@ -353,6 +362,9 @@ def execute_self_play(
353362 }
354363 with ProcessPoolExecutor (
355364 max_workers = max_workers ,
365+ # Forking after CUDA/DDP warmup can deadlock on Linux notebooks.
366+ # We force "spawn" to keep self-play workers reliable on Kaggle.
367+ mp_context = mp .get_context ("spawn" ),
356368 initializer = _init_selfplay_process_worker ,
357369 initargs = (
358370 model_state_dict ,
@@ -361,20 +373,83 @@ def execute_self_play(
361373 cfg_int ("mcts_sims" ),
362374 ),
363375 ) as executor :
364- for (
365- (_ , opponent_type , heuristic_level , _ ),
366- episode_result ,
367- ) in zip (
368- episode_specs ,
369- executor .map (_run_episode_in_process_worker , worker_payloads ),
370- strict = True ,
376+ futures : dict [Future [tuple [list [tuple [np .ndarray , np .ndarray , int ]], int , int ]], tuple [int , str , str ]] = {}
377+ submitted_at : dict [
378+ Future [tuple [list [tuple [np .ndarray , np .ndarray , int ]], int , int ]],
379+ float ,
380+ ] = {}
381+ ordered_results : list [
382+ tuple [int , str , str , list [tuple [np .ndarray , np .ndarray , int ]], int , int ]
383+ ] = []
384+ for idx , ((_ , opponent_type , heuristic_level , _ ), payload ) in enumerate (
385+ zip (episode_specs , worker_payloads , strict = True ),
386+ start = 1 ,
371387 ):
372- game_history , winner , turn_idx = episode_result
388+ future = executor .submit (_run_episode_in_process_worker , payload )
389+ futures [future ] = (idx , opponent_type , heuristic_level )
390+ submitted_at [future ] = time .perf_counter ()
391+
392+ pending = set (futures )
393+ while pending :
394+ done , pending = wait (
395+ pending ,
396+ timeout = progress_every_s ,
397+ return_when = FIRST_COMPLETED ,
398+ )
399+ now_s = time .perf_counter ()
400+ if not done :
401+ if (now_s - last_progress_log_s ) >= progress_every_s :
402+ log (
403+ f"[Iteration { iteration } ] self-play progress "
404+ f"{ len (ordered_results )} /{ episodes } elapsed={ now_s - selfplay_start :.0f} s" ,
405+ )
406+ last_progress_log_s = now_s
407+ if episode_timeout_s <= 0.0 :
408+ continue
409+ oldest_pending_s = max (now_s - submitted_at [fut ] for fut in pending )
410+ if oldest_pending_s <= episode_timeout_s :
411+ continue
412+ raise TimeoutError (
413+ "Parallel self-play stalled: "
414+ f"oldest pending episode exceeded { episode_timeout_s :.0f} s." ,
415+ )
416+
417+ for future in done :
418+ idx , opponent_type , heuristic_level = futures [future ]
419+ game_history , winner , turn_idx = future .result ()
420+ ordered_results .append (
421+ (
422+ idx ,
423+ opponent_type ,
424+ heuristic_level ,
425+ game_history ,
426+ winner ,
427+ turn_idx ,
428+ )
429+ )
430+ submitted_at .pop (future , None )
431+
432+ if (now_s - last_progress_log_s ) >= progress_every_s :
433+ log (
434+ f"[Iteration { iteration } ] self-play progress "
435+ f"{ len (ordered_results )} /{ episodes } elapsed={ now_s - selfplay_start :.0f} s" ,
436+ )
437+ last_progress_log_s = now_s
438+
439+ ordered_results .sort (key = lambda item : item [0 ])
440+ for (
441+ _idx ,
442+ opponent_type ,
443+ heuristic_level ,
444+ game_history ,
445+ winner ,
446+ turn_idx ,
447+ ) in ordered_results :
373448 episode_results .append (
374449 (opponent_type , heuristic_level , game_history , winner , turn_idx )
375450 )
376451 used_parallel = True
377- log (f" Self -play process workers active: { max_workers } " , verbose_only = True )
452+ log (f"[Iteration { iteration } ] self -play process workers active: { max_workers } " )
378453 except Exception as exc :
379454 handle_parallel_selfplay_failure (exc )
380455 episode_results .clear ()
@@ -392,6 +467,13 @@ def execute_self_play(
392467 model_player = model_player ,
393468 )
394469 episode_results .append ((opponent_type , heuristic_level , game_history , winner , turn_idx ))
470+ now_s = time .perf_counter ()
471+ if (now_s - last_progress_log_s ) >= progress_every_s :
472+ log (
473+ f"[Iteration { iteration } ] self-play progress "
474+ f"{ len (episode_results )} /{ episodes } elapsed={ now_s - selfplay_start :.0f} s" ,
475+ )
476+ last_progress_log_s = now_s
395477
396478 for episode_idx , (opponent_type , heuristic_level , game_history , winner , turn_idx ) in enumerate (
397479 episode_results ,
0 commit comments