Skip to content

Commit f2f7ce4

Browse files
committed
_shutdown_pool now runs pool.terminate() in a daemon thread with a _TERMINATE_GRACE_SECONDS wait
1 parent 5b90399 commit f2f7ce4

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

src/maxtext/trainers/post_train/rl/math_verify_pool.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import itertools
2929
import multiprocessing
3030
import os
31+
import threading
3132
import time
3233
import uuid
3334
from etils import epath
@@ -155,8 +156,16 @@ def _shutdown_pool():
155156
if w.is_alive():
156157
w.kill()
157158
w.join(timeout=1.0)
158-
pool.terminate()
159-
pool.join()
159+
# Workers SIGKILLed mid-write leave the outqueue lock orphaned, so
160+
# pool.terminate() / pool.join() block forever on the internal
161+
# _result_handler / _task_handler threads. Run terminate in a daemon
162+
# thread with a bounded wait: pool._state flips to TERMINATE so the
163+
# worker-handler stops spawning replacements, and we return even if the
164+
# handler threads never unblock. Those threads leak, but they are daemon
165+
# and cheap; a stuck trainer is not.
166+
t = threading.Thread(target=pool.terminate, daemon=True)
167+
t.start()
168+
t.join(timeout=_TERMINATE_GRACE_SECONDS)
160169
except Exception:
161170
pass
162171

@@ -176,7 +185,7 @@ def _get_pool(num_procs):
176185
return _POOL
177186

178187

179-
# ensures global worker pool is cleanly shut down when progam finishes execution
188+
# ensures global worker pool is cleanly shut down when program finishes execution
180189
atexit.register(_shutdown_pool)
181190

182191

0 commit comments

Comments
 (0)