Skip to content

Commit d4aeeea

Browse files
committed
training: add HF bootstrap mode with fresh-iteration reset
1 parent 4a50e32 commit d4aeeea

5 files changed

Lines changed: 175 additions & 25 deletions

File tree

src/training/checkpointing.py

Lines changed: 39 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,11 @@ def __init__(
7070
self.api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
7171

7272
def _repo_path(self, filename: str) -> str:
73-
return f"runs/{self.run_id}/{filename}"
73+
return self._repo_path_for_run(self.run_id, filename)
74+
75+
@staticmethod
76+
def _repo_path_for_run(run_id: str, filename: str) -> str:
77+
return f"runs/{run_id}/{filename}"
7478

7579
def save_checkpoint_local(
7680
self,
@@ -157,12 +161,23 @@ def upload_checkpoint_files(
157161
)
158162
self.cleanup_local_checkpoints(keep_last_n=keep_last_n)
159163

160-
def load_latest_checkpoint(self, *, system: AtaxxZero, buffer: ReplayBuffer) -> int:
164+
def load_latest_checkpoint(
165+
self,
166+
*,
167+
system: AtaxxZero,
168+
buffer: ReplayBuffer,
169+
run_id: str | None = None,
170+
load_buffer: bool = True,
171+
) -> int:
161172
hub_mod = __import__("huggingface_hub", fromlist=["hf_hub_download"])
162173
hf_hub_download = hub_mod.hf_hub_download
163174

175+
source_run_id = (run_id or self.run_id).strip()
176+
if source_run_id == "":
177+
raise ValueError("Checkpoint source run_id cannot be empty.")
178+
164179
files = self.api.list_repo_files(repo_id=self.repo_id, repo_type="model")
165-
run_prefix = self._repo_path("")
180+
run_prefix = self._repo_path_for_run(source_run_id, "")
166181
model_files = [
167182
f
168183
for f in files
@@ -175,7 +190,7 @@ def load_latest_checkpoint(self, *, system: AtaxxZero, buffer: ReplayBuffer) ->
175190

176191
latest_iter = max(int(Path(name).stem.split("_")[2]) for name in model_files)
177192
model_name = f"model_iter_{latest_iter:03d}.pt"
178-
model_repo_path = self._repo_path(model_name)
193+
model_repo_path = self._repo_path_for_run(source_run_id, model_name)
179194
model_path = hf_hub_download(
180195
repo_id=self.repo_id,
181196
filename=model_repo_path,
@@ -197,25 +212,26 @@ def load_latest_checkpoint(self, *, system: AtaxxZero, buffer: ReplayBuffer) ->
197212
"reentrena o usa carga parcial manual (strict=False)."
198213
) from exc
199214

200-
buffer_name = f"buffer_iter_{latest_iter:03d}.npz"
201-
buffer_repo_path = self._repo_path(buffer_name)
202-
try:
203-
buffer_path = hf_hub_download(
204-
repo_id=self.repo_id,
205-
filename=buffer_repo_path,
206-
repo_type="model",
207-
token=self.token,
208-
local_dir=str(self.local_dir),
209-
)
210-
data = np.load(buffer_path)
211-
observations = data["observations"]
212-
policies = data["policies"]
213-
values = data["values"]
214-
examples = list(zip(observations, policies, values, strict=True))
215-
buffer.clear()
216-
buffer.save_game(examples)
217-
except (OSError, KeyError, ValueError):
218-
pass
215+
if load_buffer:
216+
buffer_name = f"buffer_iter_{latest_iter:03d}.npz"
217+
buffer_repo_path = self._repo_path_for_run(source_run_id, buffer_name)
218+
try:
219+
buffer_path = hf_hub_download(
220+
repo_id=self.repo_id,
221+
filename=buffer_repo_path,
222+
repo_type="model",
223+
token=self.token,
224+
local_dir=str(self.local_dir),
225+
)
226+
data = np.load(buffer_path)
227+
observations = data["observations"]
228+
policies = data["policies"]
229+
values = data["values"]
230+
examples = list(zip(observations, policies, values, strict=True))
231+
buffer.clear()
232+
buffer.save_game(examples)
233+
except (OSError, KeyError, ValueError):
234+
pass
219235

220236
return latest_iter
221237

src/training/config_runtime.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,8 @@
5656
"hf_enabled": False,
5757
"hf_repo_id": "",
5858
"hf_run_id": "policy_spatial_v1",
59+
"hf_bootstrap_run_id": "",
60+
"hf_reset_iteration": False,
5961
"hf_token_env": "HF_TOKEN",
6062
"hf_local_dir": "hf_checkpoints",
6163
"max_pending_hf_uploads": 2,
@@ -176,6 +178,8 @@ def parse_args() -> argparse.Namespace:
176178
parser.add_argument("--hf", action="store_true")
177179
parser.add_argument("--hf-repo-id", default=None)
178180
parser.add_argument("--hf-run-id", default=None)
181+
parser.add_argument("--hf-bootstrap-run-id", default=None)
182+
parser.add_argument("--hf-reset-iteration", action="store_true")
179183
parser.add_argument("--max-pending-hf-uploads", type=int, default=None)
180184
parser.add_argument("--hf-upload-timeout-s", type=float, default=None)
181185
return parser.parse_args()
@@ -297,6 +301,10 @@ def apply_cli_overrides(args: argparse.Namespace) -> None:
297301
CONFIG["hf_repo_id"] = args.hf_repo_id
298302
if args.hf_run_id is not None:
299303
CONFIG["hf_run_id"] = args.hf_run_id.strip()
304+
if args.hf_bootstrap_run_id is not None:
305+
CONFIG["hf_bootstrap_run_id"] = args.hf_bootstrap_run_id.strip()
306+
if args.hf_reset_iteration:
307+
CONFIG["hf_reset_iteration"] = True
300308
if args.max_pending_hf_uploads is not None:
301309
CONFIG["max_pending_hf_uploads"] = max(1, args.max_pending_hf_uploads)
302310
if args.hf_upload_timeout_s is not None:

tests/test_training_checkpointing.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,12 @@
11
from __future__ import annotations
22

3+
import sys
4+
import types
35
import unittest
46
from concurrent.futures import Future
7+
from pathlib import Path
8+
from typing import Any, cast
9+
from unittest.mock import MagicMock, Mock, patch
510

611
from training.checkpointing import (
712
HuggingFaceCheckpointer,
@@ -27,6 +32,57 @@ def test_repo_path_is_namespaced_by_run_id(self) -> None:
2732
repo_path = checkpointer._repo_path("model_iter_040.pt")
2833
self.assertEqual(repo_path, "runs/policy_spatial_v1/model_iter_040.pt")
2934

35+
def test_repo_path_for_run_allows_explicit_source_namespace(self) -> None:
36+
repo_path = HuggingFaceCheckpointer._repo_path_for_run(
37+
run_id="policy_spatial_v2",
38+
filename="model_iter_001.pt",
39+
)
40+
self.assertEqual(repo_path, "runs/policy_spatial_v2/model_iter_001.pt")
41+
42+
def test_load_latest_checkpoint_can_bootstrap_from_explicit_run_without_buffer(self) -> None:
43+
sample_value = "sample_value"
44+
checkpointer = object.__new__(HuggingFaceCheckpointer)
45+
checkpointer.repo_id = "dieg0code/ataxx-zero"
46+
checkpointer.token = sample_value
47+
checkpointer.run_id = "policy_target_v2"
48+
checkpointer.local_dir = Path()
49+
checkpointer.api = Mock()
50+
checkpointer.api.list_repo_files.return_value = [
51+
"runs/policy_source_v1/model_iter_022.pt",
52+
"runs/policy_source_v1/buffer_iter_022.npz",
53+
"runs/policy_target_v2/model_iter_010.pt",
54+
]
55+
56+
hf_download_mock = MagicMock(return_value="model_iter_022.pt")
57+
hub_module = cast(Any, types.ModuleType("huggingface_hub"))
58+
hub_module.hf_hub_download = hf_download_mock
59+
60+
system = Mock()
61+
buffer = Mock()
62+
63+
with patch.dict(sys.modules, {"huggingface_hub": hub_module}), patch(
64+
"training.checkpointing.torch.load"
65+
) as torch_load_mock:
66+
torch_load_mock.return_value = {"state_dict": {}}
67+
loaded_iter = checkpointer.load_latest_checkpoint(
68+
system=system,
69+
buffer=buffer,
70+
run_id="policy_source_v1",
71+
load_buffer=False,
72+
)
73+
74+
self.assertEqual(loaded_iter, 22)
75+
hf_download_mock.assert_called_once_with(
76+
repo_id="dieg0code/ataxx-zero",
77+
filename="runs/policy_source_v1/model_iter_022.pt",
78+
repo_type="model",
79+
token=sample_value,
80+
local_dir=".",
81+
)
82+
system.load_state_dict.assert_called_once_with({})
83+
buffer.clear.assert_not_called()
84+
buffer.save_game.assert_not_called()
85+
3086
def test_ensure_hf_ready_raises_when_hf_enabled_without_checkpointer(self) -> None:
3187
CONFIG["hf_enabled"] = True
3288
CONFIG["hf_token_env"] = "HF_TOKEN" # noqa: S105 - test fixture value, not a secret.
Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
import unittest
5+
from unittest.mock import patch
6+
7+
from training.config_runtime import CONFIG, apply_cli_overrides, parse_args
8+
9+
10+
class TestTrainingConfigRuntime(unittest.TestCase):
11+
def setUp(self) -> None:
12+
self._backup = dict(CONFIG)
13+
14+
def tearDown(self) -> None:
15+
CONFIG.clear()
16+
CONFIG.update(self._backup)
17+
18+
def test_hf_bootstrap_flags_are_applied_from_cli(self) -> None:
19+
with patch.object(
20+
sys,
21+
"argv",
22+
[
23+
"train.py",
24+
"--hf",
25+
"--hf-run-id",
26+
"policy_target_v2",
27+
"--hf-bootstrap-run-id",
28+
"policy_source_v1",
29+
"--hf-reset-iteration",
30+
],
31+
):
32+
args = parse_args()
33+
34+
apply_cli_overrides(args)
35+
36+
self.assertTrue(bool(CONFIG["hf_enabled"]))
37+
self.assertEqual(str(CONFIG["hf_run_id"]), "policy_target_v2")
38+
self.assertEqual(str(CONFIG["hf_bootstrap_run_id"]), "policy_source_v1")
39+
self.assertTrue(bool(CONFIG["hf_reset_iteration"]))
40+
41+
42+
if __name__ == "__main__":
43+
unittest.main()

train.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -296,12 +296,39 @@ def main() -> None:
296296
hf_upload_futures: list[Future[None]] = []
297297
if hf_checkpointer is not None:
298298
hf_upload_executor = ThreadPoolExecutor(max_workers=1)
299+
bootstrap_run_id = cfg_str("hf_bootstrap_run_id").strip()
300+
source_run_id = bootstrap_run_id or cfg_str("hf_run_id").strip()
301+
reset_iteration = cfg_bool("hf_reset_iteration")
299302
try:
300-
start_iteration = hf_checkpointer.load_latest_checkpoint(
303+
loaded_iteration = hf_checkpointer.load_latest_checkpoint(
301304
system=system,
302305
buffer=buffer,
306+
run_id=(bootstrap_run_id or None),
307+
load_buffer=not reset_iteration,
303308
)
304-
log(f"Resumed from HF checkpoint iteration {start_iteration}.")
309+
if loaded_iteration > 0:
310+
if reset_iteration:
311+
# Fresh-run bootstrap: keep learned weights but rebuild replay
312+
# from scratch so warmup/curriculum can run from iteration 0.
313+
buffer.clear()
314+
start_iteration = 0
315+
log(
316+
"HF bootstrap loaded "
317+
f"iteration {loaded_iteration} from run_id={source_run_id}; "
318+
"resetting iteration to 0 and clearing replay buffer.",
319+
)
320+
else:
321+
start_iteration = loaded_iteration
322+
if bootstrap_run_id != "":
323+
log(
324+
"Resumed from HF checkpoint iteration "
325+
f"{start_iteration} (source run_id={source_run_id}).",
326+
)
327+
else:
328+
log(f"Resumed from HF checkpoint iteration {start_iteration}.")
329+
else:
330+
start_iteration = 0
331+
log(f"No HF checkpoint found in run_id={source_run_id}; starting from scratch.")
305332
except (ValueError, OSError):
306333
start_iteration = 0
307334
log("HF resume failed; starting from scratch.")

0 commit comments

Comments
 (0)