2626from training .checkpointing import ( # noqa: E402
2727 cleanup_local_checkpoints ,
2828 cleanup_old_log_versions ,
29+ drain_completed_hf_uploads ,
2930 ensure_hf_ready ,
3031 init_hf_checkpointer ,
3132 should_save_iteration_checkpoint ,
@@ -349,6 +350,8 @@ def main() -> None:
349350
350351 try :
351352 for iteration in range (start_iteration + 1 , iterations + 1 ):
353+ if hf_upload_futures :
354+ hf_upload_futures = drain_completed_hf_uploads (hf_upload_futures , fail_on_error = cfg_bool ("fail_on_hf_upload_error" ))
352355 epoch_pulse .set_iteration (iteration )
353356 selfplay_start = time .perf_counter ()
354357 selfplay_stats = execute_self_play (
@@ -358,6 +361,8 @@ def main() -> None:
358361 device = device ,
359362 )
360363 selfplay_s = time .perf_counter () - selfplay_start
364+ if len (buffer ) == 0 :
365+ raise RuntimeError ("Replay buffer is empty after self-play; aborting early." )
361366
362367 train_loader = _build_train_loader (buffer , device = device )
363368 val_loader = _build_val_loader (buffer , device = device )
@@ -409,10 +414,7 @@ def main() -> None:
409414 best_path = checkpoint_dir / "best_eval.ckpt"
410415 trainer .save_checkpoint (str (best_path ))
411416 except Exception as exc :
412- monitor .log_warning (
413- iteration = iteration ,
414- message = f"eval failed, continuing training: { exc } " ,
415- )
417+ monitor .log_warning (iteration = iteration , message = f"eval failed, continuing training: { exc } " )
416418
417419 if not should_save_iteration_checkpoint (
418420 iteration = iteration ,
@@ -430,10 +432,7 @@ def main() -> None:
430432 keep_last_n = cfg_int ("keep_last_n_local_checkpoints" ),
431433 )
432434 except OSError :
433- monitor .log_warning (
434- iteration = iteration ,
435- message = "local checkpoint save failed." ,
436- )
435+ monitor .log_warning (iteration = iteration , message = "local checkpoint save failed." )
437436
438437 if hf_checkpointer is not None :
439438 try :
@@ -458,10 +457,7 @@ def main() -> None:
458457 keep_last_n = cfg_int ("keep_last_n_hf_checkpoints" ),
459458 )
460459 hf_upload_futures .append (future )
461- monitor .log_warning (
462- iteration = iteration ,
463- message = f"HF upload queued for iteration { iteration } ." ,
464- )
460+ monitor .log_warning (iteration = iteration , message = f"HF upload queued for iteration { iteration } ." )
465461 else :
466462 hf_checkpointer .upload_checkpoint_files (
467463 iteration = iteration ,
@@ -470,10 +466,7 @@ def main() -> None:
470466 metadata_path = metadata_path ,
471467 keep_last_n = cfg_int ("keep_last_n_hf_checkpoints" ),
472468 )
473- monitor .log_warning (
474- iteration = iteration ,
475- message = f"HF checkpoint uploaded for iteration { iteration } ." ,
476- )
469+ monitor .log_warning (iteration = iteration , message = f"HF checkpoint uploaded for iteration { iteration } ." )
477470 except (OSError , ValueError ):
478471 monitor .log_warning (iteration = iteration , message = "HF upload failed for this iteration." )
479472 if cfg_bool ("export_onnx" ):
0 commit comments