2626from training .checkpointing import ( # noqa: E402
2727 cleanup_local_checkpoints ,
2828 cleanup_old_log_versions ,
29+ ensure_hf_ready ,
2930 init_hf_checkpointer ,
3031 should_save_iteration_checkpoint ,
3132)
@@ -198,11 +199,11 @@ def _run_warmup_if_needed(
198199 optimizer_transfer : OptimizerStateTransfer ,
199200 monitor : TrainingMonitor ,
200201 epoch_pulse : EpochPulseCallback ,
201- ) -> None :
202+ ) -> tuple [ str , int , str , TrainerPrecision ] :
202203 warmup_games = cfg_int ("warmup_games" )
203204 warmup_epochs = cfg_int ("warmup_epochs" )
204205 if start_iteration != 0 or warmup_games <= 0 or warmup_epochs <= 0 :
205- return
206+ return trainer_accelerator , trainer_devices , trainer_strategy , trainer_precision
206207
207208 # Warmup seeds the policy with legal, sensible moves before self-play noise.
208209 warmup_rng = torch .Generator ().manual_seed (cfg_int ("seed" ))
@@ -216,24 +217,28 @@ def _run_warmup_if_needed(
216217 monitor .log_warmup (examples = len (warmup_examples ), games = warmup_games )
217218 train_loader = _build_train_loader (buffer , device = device )
218219 val_loader = _build_val_loader (buffer , device = device )
219- warmup_trainer = build_trainer (
220+ (
221+ _warmup_trainer ,
222+ trainer_accelerator ,
223+ trainer_devices ,
224+ trainer_strategy ,
225+ trainer_precision ,
226+ ) = _fit_with_ddp_fallback (
227+ system = system ,
228+ train_loader = train_loader ,
229+ val_loader = val_loader ,
220230 epochs = warmup_epochs ,
221- accelerator = trainer_accelerator ,
222- devices = trainer_devices ,
223- strategy = trainer_strategy ,
224- precision = trainer_precision ,
225- benchmark = cfg_bool ("trainer_benchmark" ),
231+ trainer_accelerator = trainer_accelerator ,
232+ trainer_devices = trainer_devices ,
233+ trainer_strategy = trainer_strategy ,
234+ trainer_precision = trainer_precision ,
226235 checkpoint_callback = checkpoint_callback ,
227236 lr_monitor = lr_monitor ,
228237 logger = logger ,
229- extra_callbacks = [optimizer_transfer , epoch_pulse ],
230- )
231- system .train ()
232- warmup_trainer .fit (
233- model = system ,
234- train_dataloaders = train_loader ,
235- val_dataloaders = val_loader ,
238+ optimizer_transfer = optimizer_transfer ,
239+ epoch_pulse = epoch_pulse ,
236240 )
241+ return trainer_accelerator , trainer_devices , trainer_strategy , trainer_precision
237242
238243
239244def main () -> None :
@@ -286,6 +291,7 @@ def main() -> None:
286291 buffer = ReplayBuffer (capacity = cfg_int ("buffer_size" ))
287292
288293 hf_checkpointer = init_hf_checkpointer ()
294+ ensure_hf_ready (hf_checkpointer )
289295 hf_upload_executor : ThreadPoolExecutor | None = None
290296 hf_upload_futures : list [Future [None ]] = []
291297 if hf_checkpointer is not None :
@@ -324,7 +330,7 @@ def main() -> None:
324330 pulse_every = cfg_int ("epoch_pulse_every" ),
325331 )
326332
327- _run_warmup_if_needed (
333+ trainer_accelerator , trainer_devices , trainer_strategy , trainer_precision = _run_warmup_if_needed (
328334 start_iteration = start_iteration ,
329335 system = system ,
330336 buffer = buffer ,
@@ -469,18 +475,12 @@ def main() -> None:
469475 message = f"HF checkpoint uploaded for iteration { iteration } ." ,
470476 )
471477 except (OSError , ValueError ):
472- monitor .log_warning (
473- iteration = iteration ,
474- message = "HF upload failed for this iteration." ,
475- )
478+ monitor .log_warning (iteration = iteration , message = "HF upload failed for this iteration." )
476479 if cfg_bool ("export_onnx" ):
477480 try :
478481 export_onnx (system .model , cfg_str ("onnx_path" ), device = device )
479482 except (OSError , RuntimeError , ValueError ):
480- monitor .log_warning (
481- iteration = iteration ,
482- message = "ONNX export failed for this iteration." ,
483- )
483+ monitor .log_warning (iteration = iteration , message = "ONNX export failed for this iteration." )
484484
485485 cleanup_old_log_versions (
486486 log_dir = log_dir ,
0 commit comments