Skip to content

Commit a77ace2

Browse files
committed
Add fail-fast guards for long-running training failures
1 parent 60ecb53 commit a77ace2

6 files changed

Lines changed: 99 additions & 19 deletions

File tree

src/training/checkpointing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import os
5+
from concurrent.futures import Future
56
from datetime import datetime, timezone
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Any
@@ -275,6 +276,26 @@ def ensure_hf_ready(checkpointer: HuggingFaceCheckpointer | None) -> None:
275276
)
276277

277278

279+
def drain_completed_hf_uploads(
    futures: list[Future[None]],
    *,
    fail_on_error: bool,
) -> list[Future[None]]:
    """Prune finished Hugging Face upload futures, surfacing failures early.

    Futures that have not completed yet are returned unchanged so the caller
    can poll them again on a later iteration. Each completed future has its
    stored exception (if any) re-raised via ``result()``:

    - with ``fail_on_error`` set, the failure is escalated as a
      ``RuntimeError`` chained to the original exception, aborting the run;
    - otherwise the failure is logged and training continues best-effort.
    """
    still_running: list[Future[None]] = []
    for upload in futures:
        if not upload.done():
            still_running.append(upload)
        else:
            try:
                # result() re-raises the exception stored on a failed future.
                upload.result()
            except Exception as exc:
                if fail_on_error:
                    raise RuntimeError("HF upload future failed. Aborting early.") from exc
                log(f"HF upload failed (continuing): {exc}")
    return still_running
297+
298+
278299
def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save_every: int) -> bool:
279300
"""Always persist the final iteration even when it is not divisible by save_every."""
280301
if iteration >= total_iterations:
@@ -286,6 +307,7 @@ def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save
286307
"HuggingFaceCheckpointer",
287308
"cleanup_local_checkpoints",
288309
"cleanup_old_log_versions",
310+
"drain_completed_hf_uploads",
289311
"ensure_hf_ready",
290312
"init_hf_checkpointer",
291313
"should_save_iteration_checkpoint",

src/training/config_runtime.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
"mcts_use_amp": True,
7474
"mcts_cache_size": 100_000,
7575
"mcts_leaf_batch_size": 32,
76+
"fail_on_selfplay_parallel_error": True,
77+
"fail_on_hf_upload_error": True,
7678
"opponent_self_prob": 0.45,
7779
"opponent_heuristic_prob": 0.5,
7880
"opponent_random_prob": 0.05,
@@ -154,6 +156,8 @@ def parse_args() -> argparse.Namespace:
154156
parser.add_argument("--eval-games", type=int, default=None)
155157
parser.add_argument("--eval-sims", type=int, default=None)
156158
parser.add_argument("--selfplay-workers", type=int, default=None)
159+
parser.add_argument("--allow-selfplay-fallback", action="store_true")
160+
parser.add_argument("--allow-hf-upload-errors", action="store_true")
157161
parser.add_argument("--warmup-games", type=int, default=None)
158162
parser.add_argument("--warmup-epochs", type=int, default=None)
159163
parser.add_argument(
@@ -256,6 +260,10 @@ def apply_cli_overrides(args: argparse.Namespace) -> None:
256260
CONFIG["eval_sims"] = max(8, args.eval_sims)
257261
if args.selfplay_workers is not None:
258262
CONFIG["selfplay_workers"] = max(1, args.selfplay_workers)
263+
if args.allow_selfplay_fallback:
264+
CONFIG["fail_on_selfplay_parallel_error"] = False
265+
if args.allow_hf_upload_errors:
266+
CONFIG["fail_on_hf_upload_error"] = False
259267
if args.warmup_games is not None:
260268
CONFIG["warmup_games"] = max(0, args.warmup_games)
261269
if args.warmup_epochs is not None:

src/training/selfplay_runtime.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,15 @@ def history_to_examples(
247247
return examples
248248

249249

250+
def handle_parallel_selfplay_failure(exc: Exception) -> None:
    """React to a failure of the parallel (process-pool) self-play path.

    When the ``fail_on_selfplay_parallel_error`` config flag is enabled, the
    failure is escalated as a ``RuntimeError`` chained to *exc* so the run
    aborts immediately instead of silently degrading. Otherwise the failure
    is only logged and the caller is expected to fall back to sequential
    self-play.
    """
    if not cfg_bool("fail_on_selfplay_parallel_error"):
        log(f" Process self-play failed, falling back to sequential mode: {exc}")
        return
    raise RuntimeError(
        "Process self-play failed with parallel workers. "
        "Aborting instead of silently falling back to sequential mode.",
    ) from exc
257+
258+
250259
def execute_self_play(
251260
system: AtaxxZero,
252261
buffer: ReplayBuffer,
@@ -367,9 +376,7 @@ def execute_self_play(
367376
used_parallel = True
368377
log(f" Self-play process workers active: {max_workers}", verbose_only=True)
369378
except Exception as exc:
370-
log(
371-
f" Process self-play failed, falling back to sequential mode: {exc}",
372-
)
379+
handle_parallel_selfplay_failure(exc)
373380
episode_results.clear()
374381

375382
if not used_parallel:
@@ -422,4 +429,5 @@ def execute_self_play(
422429

423430
__all__ = [
424431
"execute_self_play",
432+
"handle_parallel_selfplay_failure",
425433
]

tests/test_training_checkpointing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

33
import unittest
4+
from concurrent.futures import Future
45

56
from training.checkpointing import (
67
HuggingFaceCheckpointer,
8+
drain_completed_hf_uploads,
79
ensure_hf_ready,
810
should_save_iteration_checkpoint,
911
)
@@ -34,6 +36,25 @@ def test_ensure_hf_ready_noop_when_hf_disabled(self) -> None:
3436
CONFIG["hf_enabled"] = False
3537
ensure_hf_ready(None)
3638

39+
def test_drain_completed_hf_uploads_keeps_pending_only(self) -> None:
40+
done: Future[None] = Future()
41+
done.set_result(None)
42+
pending: Future[None] = Future()
43+
remaining = drain_completed_hf_uploads([done, pending], fail_on_error=True)
44+
self.assertEqual(remaining, [pending])
45+
46+
def test_drain_completed_hf_uploads_raises_when_fail_fast_enabled(self) -> None:
47+
failed: Future[None] = Future()
48+
failed.set_exception(RuntimeError("upload failed"))
49+
with self.assertRaises(RuntimeError):
50+
drain_completed_hf_uploads([failed], fail_on_error=True)
51+
52+
def test_drain_completed_hf_uploads_continues_when_fail_fast_disabled(self) -> None:
53+
failed: Future[None] = Future()
54+
failed.set_exception(RuntimeError("upload failed"))
55+
remaining = drain_completed_hf_uploads([failed], fail_on_error=False)
56+
self.assertEqual(remaining, [])
57+
3758
def test_should_save_iteration_checkpoint_on_schedule(self) -> None:
3859
self.assertTrue(
3960
should_save_iteration_checkpoint(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
from training.config_runtime import CONFIG
6+
from training.selfplay_runtime import handle_parallel_selfplay_failure
7+
8+
9+
class TestTrainingSelfplayRuntime(unittest.TestCase):
    """Tests for the fail-fast behavior of parallel self-play failures."""

    def setUp(self) -> None:
        # Snapshot the mutable global CONFIG so each test can tweak it freely.
        self._saved_config = dict(CONFIG)

    def tearDown(self) -> None:
        # Restore CONFIG to its pre-test state in place (other modules hold
        # references to the same dict, so we must not rebind it).
        CONFIG.clear()
        CONFIG.update(self._saved_config)

    def test_handle_parallel_selfplay_failure_raises_in_fail_fast_mode(self) -> None:
        """With fail-fast enabled, a pool failure escalates to RuntimeError."""
        CONFIG["fail_on_selfplay_parallel_error"] = True
        with self.assertRaises(RuntimeError):
            handle_parallel_selfplay_failure(RuntimeError("pool broke"))

    def test_handle_parallel_selfplay_failure_allows_fallback_when_configured(self) -> None:
        """With fail-fast disabled, the failure is logged and not raised."""
        CONFIG["fail_on_selfplay_parallel_error"] = False
        handle_parallel_selfplay_failure(RuntimeError("pool broke"))
25+
26+
27+
# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()

train.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from training.checkpointing import ( # noqa: E402
2727
cleanup_local_checkpoints,
2828
cleanup_old_log_versions,
29+
drain_completed_hf_uploads,
2930
ensure_hf_ready,
3031
init_hf_checkpointer,
3132
should_save_iteration_checkpoint,
@@ -349,6 +350,8 @@ def main() -> None:
349350

350351
try:
351352
for iteration in range(start_iteration + 1, iterations + 1):
353+
if hf_upload_futures:
354+
hf_upload_futures = drain_completed_hf_uploads(hf_upload_futures, fail_on_error=cfg_bool("fail_on_hf_upload_error"))
352355
epoch_pulse.set_iteration(iteration)
353356
selfplay_start = time.perf_counter()
354357
selfplay_stats = execute_self_play(
@@ -358,6 +361,8 @@ def main() -> None:
358361
device=device,
359362
)
360363
selfplay_s = time.perf_counter() - selfplay_start
364+
if len(buffer) == 0:
365+
raise RuntimeError("Replay buffer is empty after self-play; aborting early.")
361366

362367
train_loader = _build_train_loader(buffer, device=device)
363368
val_loader = _build_val_loader(buffer, device=device)
@@ -409,10 +414,7 @@ def main() -> None:
409414
best_path = checkpoint_dir / "best_eval.ckpt"
410415
trainer.save_checkpoint(str(best_path))
411416
except Exception as exc:
412-
monitor.log_warning(
413-
iteration=iteration,
414-
message=f"eval failed, continuing training: {exc}",
415-
)
417+
monitor.log_warning(iteration=iteration, message=f"eval failed, continuing training: {exc}")
416418

417419
if not should_save_iteration_checkpoint(
418420
iteration=iteration,
@@ -430,10 +432,7 @@ def main() -> None:
430432
keep_last_n=cfg_int("keep_last_n_local_checkpoints"),
431433
)
432434
except OSError:
433-
monitor.log_warning(
434-
iteration=iteration,
435-
message="local checkpoint save failed.",
436-
)
435+
monitor.log_warning(iteration=iteration, message="local checkpoint save failed.")
437436

438437
if hf_checkpointer is not None:
439438
try:
@@ -458,10 +457,7 @@ def main() -> None:
458457
keep_last_n=cfg_int("keep_last_n_hf_checkpoints"),
459458
)
460459
hf_upload_futures.append(future)
461-
monitor.log_warning(
462-
iteration=iteration,
463-
message=f"HF upload queued for iteration {iteration}.",
464-
)
460+
monitor.log_warning(iteration=iteration, message=f"HF upload queued for iteration {iteration}.")
465461
else:
466462
hf_checkpointer.upload_checkpoint_files(
467463
iteration=iteration,
@@ -470,10 +466,7 @@ def main() -> None:
470466
metadata_path=metadata_path,
471467
keep_last_n=cfg_int("keep_last_n_hf_checkpoints"),
472468
)
473-
monitor.log_warning(
474-
iteration=iteration,
475-
message=f"HF checkpoint uploaded for iteration {iteration}.",
476-
)
469+
monitor.log_warning(iteration=iteration, message=f"HF checkpoint uploaded for iteration {iteration}.")
477470
except (OSError, ValueError):
478471
monitor.log_warning(iteration=iteration, message="HF upload failed for this iteration.")
479472
if cfg_bool("export_onnx"):

0 commit comments

Comments (0)