Skip to content

Commit 12fcfcf

Browse files
committed
Add HF upload timeout and backlog fail-fast guards
1 parent a77ace2 commit 12fcfcf

4 files changed

Lines changed: 61 additions & 7 deletions

File tree

src/training/checkpointing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
from concurrent.futures import Future
6+
from concurrent.futures import TimeoutError as FutureTimeoutError
67
from datetime import datetime, timezone
78
from pathlib import Path
89
from typing import TYPE_CHECKING, Any
@@ -296,6 +297,29 @@ def drain_completed_hf_uploads(
296297
return pending
297298

298299

300+
def wait_for_hf_uploads(
    futures: list[Future[None]],
    *,
    timeout_s: float,
    fail_on_error: bool,
) -> None:
    """Block until every queued HF upload future resolves, bounding each wait.

    Each future is given at most ``max(1.0, timeout_s)`` seconds to finish.
    In fail-fast mode a timeout or upload failure is escalated as a
    ``RuntimeError`` (chained to the original exception); otherwise the
    problem is logged and the remaining futures are still drained.
    """
    per_future_timeout = max(1.0, float(timeout_s))
    for pending in futures:
        try:
            pending.result(timeout=per_future_timeout)
        except FutureTimeoutError as exc:
            if fail_on_error:
                raise RuntimeError(
                    f"HF upload timed out after {per_future_timeout:.1f}s per future.",
                ) from exc
            log(f"HF upload timed out (continuing): {exc}")
        except Exception as exc:  # best-effort drain during shutdown
            if fail_on_error:
                raise RuntimeError("HF upload future failed during shutdown wait.") from exc
            log(f"HF upload failed during shutdown (continuing): {exc}")
321+
322+
299323
def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save_every: int) -> bool:
300324
"""Always persist the final iteration even when it is not divisible by save_every."""
301325
if iteration >= total_iterations:
@@ -311,4 +335,5 @@ def should_save_iteration_checkpoint(iteration: int, total_iterations: int, save
311335
"ensure_hf_ready",
312336
"init_hf_checkpointer",
313337
"should_save_iteration_checkpoint",
338+
"wait_for_hf_uploads",
314339
]

src/training/config_runtime.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
"hf_run_id": "policy_spatial_v1",
5959
"hf_token_env": "HF_TOKEN",
6060
"hf_local_dir": "hf_checkpoints",
61+
"max_pending_hf_uploads": 2,
62+
"hf_upload_future_timeout_s": 120.0,
6163
"show_progress_bar": False,
6264
"trainer_log_every_n_steps": 99_999,
6365
"monitor_log_every": 5,
@@ -170,6 +172,8 @@ def parse_args() -> argparse.Namespace:
170172
parser.add_argument("--hf", action="store_true")
171173
parser.add_argument("--hf-repo-id", default=None)
172174
parser.add_argument("--hf-run-id", default=None)
175+
parser.add_argument("--max-pending-hf-uploads", type=int, default=None)
176+
parser.add_argument("--hf-upload-timeout-s", type=float, default=None)
173177
return parser.parse_args()
174178

175179

@@ -285,6 +289,10 @@ def apply_cli_overrides(args: argparse.Namespace) -> None:
285289
CONFIG["hf_repo_id"] = args.hf_repo_id
286290
if args.hf_run_id is not None:
287291
CONFIG["hf_run_id"] = args.hf_run_id.strip()
292+
if args.max_pending_hf_uploads is not None:
293+
CONFIG["max_pending_hf_uploads"] = max(1, args.max_pending_hf_uploads)
294+
if args.hf_upload_timeout_s is not None:
295+
CONFIG["hf_upload_future_timeout_s"] = max(1.0, args.hf_upload_timeout_s)
288296

289297

290298
def cfg_int(key: str) -> int:
@@ -344,6 +352,10 @@ def validate_config() -> None:
344352
raise ValueError("CONFIG['mcts_cache_size'] must be >= 0.")
345353
if cfg_int("ddp_timeout_seconds") <= 0:
346354
raise ValueError("CONFIG['ddp_timeout_seconds'] must be > 0.")
355+
if cfg_int("max_pending_hf_uploads") <= 0:
356+
raise ValueError("CONFIG['max_pending_hf_uploads'] must be > 0.")
357+
if cfg_float("hf_upload_future_timeout_s") <= 0.0:
358+
raise ValueError("CONFIG['hf_upload_future_timeout_s'] must be > 0.")
347359

348360
opp_sum = (
349361
cfg_float("opponent_self_prob")

tests/test_training_checkpointing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
drain_completed_hf_uploads,
99
ensure_hf_ready,
1010
should_save_iteration_checkpoint,
11+
wait_for_hf_uploads,
1112
)
1213
from training.config_runtime import CONFIG
1314

@@ -55,6 +56,15 @@ def test_drain_completed_hf_uploads_continues_when_fail_fast_disabled(self) -> N
5556
remaining = drain_completed_hf_uploads([failed], fail_on_error=False)
5657
self.assertEqual(remaining, [])
5758

59+
def test_wait_for_hf_uploads_raises_timeout_in_fail_fast_mode(self) -> None:
    """A never-completing future must escalate to RuntimeError when fail-fast is on."""
    stuck: Future[None] = Future()
    with self.assertRaises(RuntimeError):
        wait_for_hf_uploads([stuck], timeout_s=1.0, fail_on_error=True)
63+
64+
def test_wait_for_hf_uploads_tolerates_timeout_when_configured(self) -> None:
    """With fail-fast disabled, a timed-out future is tolerated and must not raise."""
    stuck: Future[None] = Future()
    wait_for_hf_uploads([stuck], timeout_s=1.0, fail_on_error=False)
67+
5868
def test_should_save_iteration_checkpoint_on_schedule(self) -> None:
5969
self.assertTrue(
6070
should_save_iteration_checkpoint(

train.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ensure_hf_ready,
3131
init_hf_checkpointer,
3232
should_save_iteration_checkpoint,
33+
wait_for_hf_uploads,
3334
)
3435
from training.config_runtime import ( # noqa: E402
3536
CONFIG,
@@ -352,6 +353,8 @@ def main() -> None:
352353
for iteration in range(start_iteration + 1, iterations + 1):
353354
if hf_upload_futures:
354355
hf_upload_futures = drain_completed_hf_uploads(hf_upload_futures, fail_on_error=cfg_bool("fail_on_hf_upload_error"))
356+
if len(hf_upload_futures) > cfg_int("max_pending_hf_uploads"):
357+
raise RuntimeError("HF upload backlog is growing; aborting early.")
355358
epoch_pulse.set_iteration(iteration)
356359
selfplay_start = time.perf_counter()
357360
selfplay_stats = execute_self_play(
@@ -458,6 +461,8 @@ def main() -> None:
458461
)
459462
hf_upload_futures.append(future)
460463
monitor.log_warning(iteration=iteration, message=f"HF upload queued for iteration {iteration}.")
464+
if len(hf_upload_futures) > cfg_int("max_pending_hf_uploads"):
465+
raise RuntimeError("HF upload backlog exceeded configured threshold.")
461466
else:
462467
hf_checkpointer.upload_checkpoint_files(
463468
iteration=iteration,
@@ -482,12 +487,14 @@ def main() -> None:
482487
)
483488
finally:
484489
if hf_upload_executor is not None:
485-
for future in hf_upload_futures:
486-
try:
487-
future.result()
488-
except Exception:
489-
log("A queued HF upload failed.")
490-
hf_upload_executor.shutdown(wait=True)
491-
490+
try:
491+
wait_for_hf_uploads(
492+
hf_upload_futures,
493+
timeout_s=cfg_float("hf_upload_future_timeout_s"),
494+
fail_on_error=cfg_bool("fail_on_hf_upload_error"),
495+
)
496+
except Exception as exc:
497+
log(f"HF upload wait failed: {exc}")
498+
hf_upload_executor.shutdown(wait=False, cancel_futures=True)
492499
if __name__ == "__main__":
493500
main()

0 commit comments

Comments (0)