Skip to content

Commit 60ecb53

Browse files
committed
Fail fast DDP startup and HF checkpoint readiness
1 parent 1950674 commit 60ecb53

6 files changed

Lines changed: 124 additions & 25 deletions

File tree

src/training/checkpointing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,23 @@ def init_hf_checkpointer() -> HuggingFaceCheckpointer | None:
258258
)
259259

260260

261+
def ensure_hf_ready(checkpointer: HuggingFaceCheckpointer | None) -> None:
262+
"""Fail fast when HF checkpointing was requested but cannot be initialized."""
263+
if not cfg_bool("hf_enabled"):
264+
return
265+
if checkpointer is not None:
266+
log(
267+
"HF checkpointing enabled: "
268+
f"repo={cfg_str('hf_repo_id').strip()} run_id={cfg_str('hf_run_id').strip()}",
269+
)
270+
return
271+
token_env = cfg_str("hf_token_env")
272+
raise RuntimeError(
273+
"HF checkpointing requested (--hf) but initialization failed. "
274+
f"Check token env '{token_env}', hf_repo_id, and hf_run_id before starting.",
275+
)
276+
277+
261278
def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save_every: int) -> bool:
262279
"""Always persist the final iteration even when it is not divisible by save_every."""
263280
if iteration >= total_iterations:
@@ -269,6 +286,7 @@ def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save
269286
"HuggingFaceCheckpointer",
270287
"cleanup_local_checkpoints",
271288
"cleanup_old_log_versions",
289+
"ensure_hf_ready",
272290
"init_hf_checkpointer",
273291
"should_save_iteration_checkpoint",
274292
]

src/training/config_runtime.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"trainer_strategy": "auto",
7070
"trainer_precision": "bf16-mixed",
7171
"trainer_benchmark": True,
72+
"ddp_timeout_seconds": 180,
7273
"mcts_use_amp": True,
7374
"mcts_cache_size": 100_000,
7475
"mcts_leaf_batch_size": 32,
@@ -121,6 +122,7 @@ def parse_args() -> argparse.Namespace:
121122
parser.add_argument("--keep-log-versions", type=int, default=None)
122123
parser.add_argument("--devices", type=int, default=None)
123124
parser.add_argument("--strategy", default=None)
125+
parser.add_argument("--ddp-timeout-s", type=int, default=None)
124126
parser.add_argument(
125127
"--precision",
126128
choices=["16-mixed", "bf16-mixed", "32-true"],
@@ -206,6 +208,8 @@ def apply_cli_overrides(args: argparse.Namespace) -> None:
206208
CONFIG["trainer_devices"] = max(1, args.devices)
207209
if args.strategy is not None:
208210
CONFIG["trainer_strategy"] = args.strategy
211+
if args.ddp_timeout_s is not None:
212+
CONFIG["ddp_timeout_seconds"] = max(30, args.ddp_timeout_s)
209213
if args.precision is not None:
210214
CONFIG["trainer_precision"] = args.precision
211215
if args.num_workers is not None:
@@ -330,6 +334,8 @@ def validate_config() -> None:
330334
raise ValueError("CONFIG['value_loss_coeff'] must be >= 0.")
331335
if cfg_int("mcts_cache_size") < 0:
332336
raise ValueError("CONFIG['mcts_cache_size'] must be >= 0.")
337+
if cfg_int("ddp_timeout_seconds") <= 0:
338+
raise ValueError("CONFIG['ddp_timeout_seconds'] must be > 0.")
333339

334340
opp_sum = (
335341
cfg_float("opponent_self_prob")

src/training/trainer_runtime.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
4+
35
import pytorch_lightning as pl
46
import torch
57
from pytorch_lightning import Callback
68
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
79
from pytorch_lightning.loggers import TensorBoardLogger
10+
from pytorch_lightning.strategies import DDPStrategy
811

912
from training.config_runtime import (
1013
TrainerPrecision,
@@ -43,6 +46,15 @@ def is_ddp_rendezvous_timeout(exc: BaseException) -> bool:
4346
)
4447

4548

49+
def resolve_trainer_strategy(strategy: str) -> str | DDPStrategy:
50+
timeout = timedelta(seconds=max(30, cfg_int("ddp_timeout_seconds")))
51+
if strategy == "ddp":
52+
return DDPStrategy(timeout=timeout, start_method="popen")
53+
if strategy == "ddp_spawn":
54+
return DDPStrategy(timeout=timeout, start_method="spawn")
55+
return strategy
56+
57+
4658
def build_trainer(
4759
*,
4860
epochs: int,
@@ -59,11 +71,12 @@ def build_trainer(
5971
callbacks: list[Callback] = [checkpoint_callback, lr_monitor]
6072
if extra_callbacks is not None:
6173
callbacks.extend(extra_callbacks)
74+
resolved_strategy = resolve_trainer_strategy(strategy)
6275
return pl.Trainer(
6376
max_epochs=epochs,
6477
accelerator=accelerator,
6578
devices=devices,
66-
strategy=strategy,
79+
strategy=resolved_strategy,
6780
precision=precision,
6881
benchmark=benchmark,
6982
callbacks=callbacks,
@@ -107,4 +120,5 @@ def export_onnx(model: torch.nn.Module, path: str, device: str) -> None:
107120
"is_ddp_rendezvous_timeout",
108121
"resolve_trainer_hw",
109122
"resolve_trainer_precision",
123+
"resolve_trainer_strategy",
110124
]

tests/test_training_checkpointing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,36 @@
44

55
from training.checkpointing import (
66
HuggingFaceCheckpointer,
7+
ensure_hf_ready,
78
should_save_iteration_checkpoint,
89
)
10+
from training.config_runtime import CONFIG
911

1012

1113
class TestTrainingCheckpointing(unittest.TestCase):
14+
def setUp(self) -> None:
15+
self._backup = dict(CONFIG)
16+
17+
def tearDown(self) -> None:
18+
CONFIG.clear()
19+
CONFIG.update(self._backup)
20+
1221
def test_repo_path_is_namespaced_by_run_id(self) -> None:
1322
checkpointer = object.__new__(HuggingFaceCheckpointer)
1423
checkpointer.run_id = "policy_spatial_v1"
1524
repo_path = checkpointer._repo_path("model_iter_040.pt")
1625
self.assertEqual(repo_path, "runs/policy_spatial_v1/model_iter_040.pt")
1726

27+
def test_ensure_hf_ready_raises_when_hf_enabled_without_checkpointer(self) -> None:
28+
CONFIG["hf_enabled"] = True
29+
CONFIG["hf_token_env"] = "HF_TOKEN" # noqa: S105 - test fixture value, not a secret.
30+
with self.assertRaises(RuntimeError):
31+
ensure_hf_ready(None)
32+
33+
def test_ensure_hf_ready_noop_when_hf_disabled(self) -> None:
34+
CONFIG["hf_enabled"] = False
35+
ensure_hf_ready(None)
36+
1837
def test_should_save_iteration_checkpoint_on_schedule(self) -> None:
1938
self.assertTrue(
2039
should_save_iteration_checkpoint(
tests/test_training_trainer_runtime.py (filename inferred from the diff contents)

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
from datetime import timedelta
5+
6+
from pytorch_lightning.strategies import DDPStrategy
7+
8+
from training.config_runtime import CONFIG
9+
from training.trainer_runtime import resolve_trainer_strategy
10+
11+
12+
class TestTrainingTrainerRuntime(unittest.TestCase):
13+
def setUp(self) -> None:
14+
self._backup = dict(CONFIG)
15+
16+
def tearDown(self) -> None:
17+
CONFIG.clear()
18+
CONFIG.update(self._backup)
19+
20+
def test_resolve_trainer_strategy_ddp_uses_configured_timeout(self) -> None:
21+
CONFIG["ddp_timeout_seconds"] = 75
22+
resolved = resolve_trainer_strategy("ddp")
23+
self.assertIsInstance(resolved, DDPStrategy)
24+
strategy = resolved
25+
self.assertEqual(strategy._timeout, timedelta(seconds=75))
26+
self.assertEqual(strategy._start_method, "popen")
27+
28+
def test_resolve_trainer_strategy_ddp_spawn_uses_spawn_start_method(self) -> None:
29+
CONFIG["ddp_timeout_seconds"] = 90
30+
resolved = resolve_trainer_strategy("ddp_spawn")
31+
self.assertIsInstance(resolved, DDPStrategy)
32+
strategy = resolved
33+
self.assertEqual(strategy._timeout, timedelta(seconds=90))
34+
self.assertEqual(strategy._start_method, "spawn")
35+
36+
def test_resolve_trainer_strategy_passthrough_for_auto(self) -> None:
37+
resolved = resolve_trainer_strategy("auto")
38+
self.assertEqual(resolved, "auto")
39+
40+
41+
if __name__ == "__main__":
42+
unittest.main()

train.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from training.checkpointing import ( # noqa: E402
2727
cleanup_local_checkpoints,
2828
cleanup_old_log_versions,
29+
ensure_hf_ready,
2930
init_hf_checkpointer,
3031
should_save_iteration_checkpoint,
3132
)
@@ -198,11 +199,11 @@ def _run_warmup_if_needed(
198199
optimizer_transfer: OptimizerStateTransfer,
199200
monitor: TrainingMonitor,
200201
epoch_pulse: EpochPulseCallback,
201-
) -> None:
202+
) -> tuple[str, int, str, TrainerPrecision]:
202203
warmup_games = cfg_int("warmup_games")
203204
warmup_epochs = cfg_int("warmup_epochs")
204205
if start_iteration != 0 or warmup_games <= 0 or warmup_epochs <= 0:
205-
return
206+
return trainer_accelerator, trainer_devices, trainer_strategy, trainer_precision
206207

207208
# Warmup seeds the policy with legal, sensible moves before self-play noise.
208209
warmup_rng = torch.Generator().manual_seed(cfg_int("seed"))
@@ -216,24 +217,28 @@ def _run_warmup_if_needed(
216217
monitor.log_warmup(examples=len(warmup_examples), games=warmup_games)
217218
train_loader = _build_train_loader(buffer, device=device)
218219
val_loader = _build_val_loader(buffer, device=device)
219-
warmup_trainer = build_trainer(
220+
(
221+
_warmup_trainer,
222+
trainer_accelerator,
223+
trainer_devices,
224+
trainer_strategy,
225+
trainer_precision,
226+
) = _fit_with_ddp_fallback(
227+
system=system,
228+
train_loader=train_loader,
229+
val_loader=val_loader,
220230
epochs=warmup_epochs,
221-
accelerator=trainer_accelerator,
222-
devices=trainer_devices,
223-
strategy=trainer_strategy,
224-
precision=trainer_precision,
225-
benchmark=cfg_bool("trainer_benchmark"),
231+
trainer_accelerator=trainer_accelerator,
232+
trainer_devices=trainer_devices,
233+
trainer_strategy=trainer_strategy,
234+
trainer_precision=trainer_precision,
226235
checkpoint_callback=checkpoint_callback,
227236
lr_monitor=lr_monitor,
228237
logger=logger,
229-
extra_callbacks=[optimizer_transfer, epoch_pulse],
230-
)
231-
system.train()
232-
warmup_trainer.fit(
233-
model=system,
234-
train_dataloaders=train_loader,
235-
val_dataloaders=val_loader,
238+
optimizer_transfer=optimizer_transfer,
239+
epoch_pulse=epoch_pulse,
236240
)
241+
return trainer_accelerator, trainer_devices, trainer_strategy, trainer_precision
237242

238243

239244
def main() -> None:
@@ -286,6 +291,7 @@ def main() -> None:
286291
buffer = ReplayBuffer(capacity=cfg_int("buffer_size"))
287292

288293
hf_checkpointer = init_hf_checkpointer()
294+
ensure_hf_ready(hf_checkpointer)
289295
hf_upload_executor: ThreadPoolExecutor | None = None
290296
hf_upload_futures: list[Future[None]] = []
291297
if hf_checkpointer is not None:
@@ -324,7 +330,7 @@ def main() -> None:
324330
pulse_every=cfg_int("epoch_pulse_every"),
325331
)
326332

327-
_run_warmup_if_needed(
333+
trainer_accelerator, trainer_devices, trainer_strategy, trainer_precision = _run_warmup_if_needed(
328334
start_iteration=start_iteration,
329335
system=system,
330336
buffer=buffer,
@@ -469,18 +475,12 @@ def main() -> None:
469475
message=f"HF checkpoint uploaded for iteration {iteration}.",
470476
)
471477
except (OSError, ValueError):
472-
monitor.log_warning(
473-
iteration=iteration,
474-
message="HF upload failed for this iteration.",
475-
)
478+
monitor.log_warning(iteration=iteration, message="HF upload failed for this iteration.")
476479
if cfg_bool("export_onnx"):
477480
try:
478481
export_onnx(system.model, cfg_str("onnx_path"), device=device)
479482
except (OSError, RuntimeError, ValueError):
480-
monitor.log_warning(
481-
iteration=iteration,
482-
message="ONNX export failed for this iteration.",
483-
)
483+
monitor.log_warning(iteration=iteration, message="ONNX export failed for this iteration.")
484484

485485
cleanup_old_log_versions(
486486
log_dir=log_dir,

0 commit comments

Comments (0)