1616src = root / "src"
1717if str (src ) not in sys .path :
1818 sys .path .insert (0 , str (src ))
19-
2019if TYPE_CHECKING :
2120 from data .replay_buffer import ReplayBuffer
2221 from model .system import AtaxxZero
23-
2422from training .bootstrap import generate_imitation_data # noqa: E402
2523from training .callbacks import OptimizerStateTransfer # noqa: E402
2624from training .checkpointing import ( # noqa: E402
@@ -73,7 +71,7 @@ def _build_train_loader(buffer: ReplayBuffer, device: str) -> DataLoader[object]
7371 batch_size = cfg_int ("batch_size" ),
7472 shuffle = True ,
7573 num_workers = cfg_int ("num_workers" ),
76- persistent_workers = True ,
74+ persistent_workers = cfg_bool ( "persistent_workers" ) ,
7775 pin_memory = (device == "cuda" ),
7876 prefetch_factor = 2 ,
7977 )
@@ -99,7 +97,7 @@ def _build_val_loader(buffer: ReplayBuffer, device: str) -> DataLoader[object] |
9997 batch_size = cfg_int ("batch_size" ),
10098 shuffle = False ,
10199 num_workers = cfg_int ("num_workers" ),
102- persistent_workers = True ,
100+ persistent_workers = cfg_bool ( "persistent_workers" ) ,
103101 pin_memory = (device == "cuda" ),
104102 prefetch_factor = 2 ,
105103 )
@@ -494,6 +492,8 @@ def main() -> None:
494492 fail_on_error = cfg_bool ("fail_on_hf_upload_error" ),
495493 )
496494 except Exception as exc :
495+ if cfg_bool ("fail_on_hf_upload_error" ):
496+ raise
497497 log (f"HF upload wait failed: { exc } " )
498498 hf_upload_executor .shutdown (wait = False , cancel_futures = True )
499499if __name__ == "__main__" :
0 commit comments