@@ -70,7 +70,11 @@ def __init__(
7070 self .api .create_repo (repo_id = repo_id , repo_type = "model" , exist_ok = True )
7171
7272 def _repo_path (self , filename : str ) -> str :
73- return f"runs/{ self .run_id } /{ filename } "
73+ return self ._repo_path_for_run (self .run_id , filename )
74+
75+ @staticmethod
76+ def _repo_path_for_run (run_id : str , filename : str ) -> str :
77+ return f"runs/{ run_id } /{ filename } "
7478
7579 def save_checkpoint_local (
7680 self ,
@@ -157,12 +161,23 @@ def upload_checkpoint_files(
157161 )
158162 self .cleanup_local_checkpoints (keep_last_n = keep_last_n )
159163
160- def load_latest_checkpoint (self , * , system : AtaxxZero , buffer : ReplayBuffer ) -> int :
164+ def load_latest_checkpoint (
165+ self ,
166+ * ,
167+ system : AtaxxZero ,
168+ buffer : ReplayBuffer ,
169+ run_id : str | None = None ,
170+ load_buffer : bool = True ,
171+ ) -> int :
161172 hub_mod = __import__ ("huggingface_hub" , fromlist = ["hf_hub_download" ])
162173 hf_hub_download = hub_mod .hf_hub_download
163174
175+ source_run_id = (run_id or self .run_id ).strip ()
176+ if source_run_id == "" :
177+ raise ValueError ("Checkpoint source run_id cannot be empty." )
178+
164179 files = self .api .list_repo_files (repo_id = self .repo_id , repo_type = "model" )
165- run_prefix = self ._repo_path ( "" )
180+ run_prefix = self ._repo_path_for_run ( source_run_id , "" )
166181 model_files = [
167182 f
168183 for f in files
@@ -175,7 +190,7 @@ def load_latest_checkpoint(self, *, system: AtaxxZero, buffer: ReplayBuffer) ->
175190
176191 latest_iter = max (int (Path (name ).stem .split ("_" )[2 ]) for name in model_files )
177192 model_name = f"model_iter_{ latest_iter :03d} .pt"
178- model_repo_path = self ._repo_path ( model_name )
193+ model_repo_path = self ._repo_path_for_run ( source_run_id , model_name )
179194 model_path = hf_hub_download (
180195 repo_id = self .repo_id ,
181196 filename = model_repo_path ,
@@ -197,25 +212,26 @@ def load_latest_checkpoint(self, *, system: AtaxxZero, buffer: ReplayBuffer) ->
197212 "reentrena o usa carga parcial manual (strict=False)."
198213 ) from exc
199214
200- buffer_name = f"buffer_iter_{ latest_iter :03d} .npz"
201- buffer_repo_path = self ._repo_path (buffer_name )
202- try :
203- buffer_path = hf_hub_download (
204- repo_id = self .repo_id ,
205- filename = buffer_repo_path ,
206- repo_type = "model" ,
207- token = self .token ,
208- local_dir = str (self .local_dir ),
209- )
210- data = np .load (buffer_path )
211- observations = data ["observations" ]
212- policies = data ["policies" ]
213- values = data ["values" ]
214- examples = list (zip (observations , policies , values , strict = True ))
215- buffer .clear ()
216- buffer .save_game (examples )
217- except (OSError , KeyError , ValueError ):
218- pass
215+ if load_buffer :
216+ buffer_name = f"buffer_iter_{ latest_iter :03d} .npz"
217+ buffer_repo_path = self ._repo_path_for_run (source_run_id , buffer_name )
218+ try :
219+ buffer_path = hf_hub_download (
220+ repo_id = self .repo_id ,
221+ filename = buffer_repo_path ,
222+ repo_type = "model" ,
223+ token = self .token ,
224+ local_dir = str (self .local_dir ),
225+ )
226+ data = np .load (buffer_path )
227+ observations = data ["observations" ]
228+ policies = data ["policies" ]
229+ values = data ["values" ]
230+ examples = list (zip (observations , policies , values , strict = True ))
231+ buffer .clear ()
232+ buffer .save_game (examples )
233+ except (OSError , KeyError , ValueError ):
234+ pass
219235
220236 return latest_iter
221237
0 commit comments