Skip to content

Commit d2c90ff

Browse files
[feat] main: route checkpoint IO through CheckpointManager
Construct a CheckpointManager rooted at model_dir and use it for both save and load across train/eval/export/predict. Training saves now go through ckpt_manager.save() (which triggers async prune) and close() drains pending deletions before return; restore and latest/best discovery delegate to the manager. fine_tune_checkpoint discovery stays a free-function call since it resolves an external dir outside model_dir. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0273efb commit d2c90ff

1 file changed

Lines changed: 28 additions & 41 deletions

File tree

tzrec/main.py

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def _train_and_evaluate(
323323
model_dir: str,
324324
train_config: TrainConfig,
325325
eval_config: EvalConfig,
326+
ckpt_manager: checkpoint_util.CheckpointManager,
326327
skip_steps: int = -1,
327328
ckpt_path: Optional[str] = None,
328329
eval_result_filename: str = TRAIN_EVAL_RESULT_FILENAME,
@@ -412,7 +413,7 @@ def _train_and_evaluate(
412413
# Restore model and optimizer checkpoint
413414
if i_step == 0 and ckpt_path is not None:
414415
if ignore_restore_optimizer:
415-
checkpoint_util.restore_model(
416+
ckpt_manager.restore(
416417
ckpt_path, model, None, train_config.fine_tune_ckpt_param_map
417418
)
418419
else:
@@ -421,7 +422,7 @@ def _train_and_evaluate(
421422
peek_batch = next(train_iterator)
422423
pipeline.progress(iter([peek_batch]))
423424
train_iterator = itertools.chain([peek_batch], train_iterator)
424-
checkpoint_util.restore_model(
425+
ckpt_manager.restore(
425426
ckpt_path, model, optimizer, train_config.fine_tune_ckpt_param_map
426427
)
427428

@@ -459,13 +460,7 @@ def _train_and_evaluate(
459460
if save_checkpoints_steps > 0 and i_step > 0:
460461
if i_step % save_checkpoints_steps == 0:
461462
last_ckpt_step = i_step
462-
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
463-
checkpoint_util.save_model(
464-
ckpt_dir,
465-
model,
466-
optimizer,
467-
)
468-
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
463+
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
469464
if eval_dataloader is not None:
470465
_evaluate(
471466
model,
@@ -484,13 +479,7 @@ def _train_and_evaluate(
484479
if save_checkpoints_epochs > 0 and i_step > 0:
485480
if (i_epoch + 1) % save_checkpoints_epochs == 0:
486481
last_ckpt_step = i_step
487-
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
488-
checkpoint_util.save_model(
489-
ckpt_dir,
490-
model,
491-
optimizer,
492-
)
493-
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
482+
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
494483
if eval_dataloader is not None:
495484
_evaluate(
496485
model,
@@ -525,13 +514,7 @@ def _train_and_evaluate(
525514
if train_config.is_profiling:
526515
prof.stop()
527516
if last_ckpt_step != i_step:
528-
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
529-
checkpoint_util.save_model(
530-
ckpt_dir,
531-
model,
532-
optimizer,
533-
)
534-
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
517+
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
535518
if eval_dataloader is not None:
536519
_evaluate(
537520
model,
@@ -544,6 +527,7 @@ def _train_and_evaluate(
544527
check_all_workers_data_status=check_all_workers_data_status,
545528
)
546529
model.train()
530+
ckpt_manager.close()
547531

548532

549533
def train_and_evaluate(
@@ -610,10 +594,17 @@ def train_and_evaluate(
610594
gl_cluster=gl_cluster,
611595
)
612596

597+
ckpt_manager = checkpoint_util.CheckpointManager(
598+
pipeline_config.model_dir,
599+
keep_checkpoint_max=train_config.keep_checkpoint_max,
600+
export_config=pipeline_config.export_config,
601+
)
602+
613603
# Get Restore Ckpt Path
614604
ckpt_path = None
615605
skip_steps = -1
616606
if pipeline_config.train_config.fine_tune_checkpoint:
607+
# fine_tune_checkpoint is an external dir, outside the manager's model_dir.
617608
ckpt_path, _ = checkpoint_util.latest_checkpoint(
618609
pipeline_config.train_config.fine_tune_checkpoint
619610
)
@@ -624,9 +615,7 @@ def train_and_evaluate(
624615
)
625616
if os.path.exists(pipeline_config.model_dir):
626617
# Restore dataloader state if continuing training
627-
latest_ckpt_path, skip_steps = checkpoint_util.latest_checkpoint(
628-
pipeline_config.model_dir
629-
)
618+
latest_ckpt_path, skip_steps = ckpt_manager.latest_checkpoint()
630619
if latest_ckpt_path:
631620
if continue_train:
632621
ckpt_path = latest_ckpt_path
@@ -641,7 +630,7 @@ def train_and_evaluate(
641630
# Restore dataloader checkpoint state
642631
dataloader_state: Optional[Dict[str, int]] = None
643632
if ckpt_path and continue_train:
644-
dataloader_state = checkpoint_util.restore_dataloader_state(ckpt_path)
633+
dataloader_state = ckpt_manager.restore_dataloader_state(ckpt_path)
645634
if dataloader_state:
646635
train_dataloader.dataset.load_state_dict(dataloader_state)
647636

@@ -778,6 +767,7 @@ def train_and_evaluate(
778767
pipeline_config.model_dir,
779768
train_config=train_config,
780769
eval_config=pipeline_config.eval_config,
770+
ckpt_manager=ckpt_manager,
781771
skip_steps=skip_steps,
782772
ckpt_path=ckpt_path,
783773
check_all_workers_data_status=check_all_workers_data_status,
@@ -837,11 +827,10 @@ def evaluate(
837827
model, device=device, mixed_precision=train_config.mixed_precision
838828
)
839829

830+
ckpt_manager = checkpoint_util.CheckpointManager(pipeline_config.model_dir)
840831
global_step = None
841832
if not checkpoint_path:
842-
checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
843-
pipeline_config.model_dir
844-
)
833+
checkpoint_path, global_step = ckpt_manager.latest_checkpoint()
845834
planner = create_planner(
846835
device=device,
847836
# pyre-ignore [16]
@@ -864,7 +853,7 @@ def evaluate(
864853
)
865854

866855
if checkpoint_path:
867-
checkpoint_util.restore_model(checkpoint_path, model)
856+
ckpt_manager.restore(checkpoint_path, model)
868857
else:
869858
raise ValueError("Eval checkpoint path should be specified.")
870859

@@ -947,17 +936,16 @@ def export(
947936
model = InferWrapper(model)
948937

949938
if not checkpoint_path:
939+
ckpt_manager = checkpoint_util.CheckpointManager(
940+
pipeline_config.model_dir, export_config=pipeline_config.export_config
941+
)
950942
if (
951943
pipeline_config.HasField("export_config")
952944
and pipeline_config.export_config.exporter_type == "best"
953945
):
954-
checkpoint_path, _ = checkpoint_util.best_checkpoint(
955-
pipeline_config.model_dir, pipeline_config.export_config
956-
)
946+
checkpoint_path, _ = ckpt_manager.best_checkpoint()
957947
else:
958-
checkpoint_path, _ = checkpoint_util.latest_checkpoint(
959-
pipeline_config.model_dir
960-
)
948+
checkpoint_path, _ = ckpt_manager.latest_checkpoint()
961949

962950
if isinstance(model.model, MatchModel):
963951
for name, module in model.model.named_children():
@@ -1386,11 +1374,10 @@ def predict_checkpoint(
13861374
output_cols=output_cols,
13871375
)
13881376

1377+
ckpt_manager = checkpoint_util.CheckpointManager(pipeline_config.model_dir)
13891378
global_step = None
13901379
if not checkpoint_path:
1391-
checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
1392-
pipeline_config.model_dir
1393-
)
1380+
checkpoint_path, global_step = ckpt_manager.latest_checkpoint()
13941381
planner = create_planner(
13951382
device=device,
13961383
# pyre-ignore [16]
@@ -1413,7 +1400,7 @@ def predict_checkpoint(
14131400
model.eval()
14141401

14151402
if checkpoint_path:
1416-
checkpoint_util.restore_model(checkpoint_path, model)
1403+
ckpt_manager.restore(checkpoint_path, model)
14171404
else:
14181405
raise ValueError("Predict checkpoint path should be specified.")
14191406

0 commit comments

Comments
 (0)