@@ -323,6 +323,7 @@ def _train_and_evaluate(
323323 model_dir : str ,
324324 train_config : TrainConfig ,
325325 eval_config : EvalConfig ,
326+ ckpt_manager : checkpoint_util .CheckpointManager ,
326327 skip_steps : int = - 1 ,
327328 ckpt_path : Optional [str ] = None ,
328329 eval_result_filename : str = TRAIN_EVAL_RESULT_FILENAME ,
@@ -412,7 +413,7 @@ def _train_and_evaluate(
412413 # Restore model and optimizer checkpoint
413414 if i_step == 0 and ckpt_path is not None :
414415 if ignore_restore_optimizer :
415- checkpoint_util . restore_model (
416+ ckpt_manager . restore (
416417 ckpt_path , model , None , train_config .fine_tune_ckpt_param_map
417418 )
418419 else :
@@ -421,7 +422,7 @@ def _train_and_evaluate(
421422 peek_batch = next (train_iterator )
422423 pipeline .progress (iter ([peek_batch ]))
423424 train_iterator = itertools .chain ([peek_batch ], train_iterator )
424- checkpoint_util . restore_model (
425+ ckpt_manager . restore (
425426 ckpt_path , model , optimizer , train_config .fine_tune_ckpt_param_map
426427 )
427428
@@ -459,13 +460,7 @@ def _train_and_evaluate(
459460 if save_checkpoints_steps > 0 and i_step > 0 :
460461 if i_step % save_checkpoints_steps == 0 :
461462 last_ckpt_step = i_step
462- ckpt_dir = os .path .join (model_dir , f"model.ckpt-{ i_step } " )
463- checkpoint_util .save_model (
464- ckpt_dir ,
465- model ,
466- optimizer ,
467- )
468- checkpoint_util .save_dataloader_state (ckpt_dir , dataloader_state )
463+ ckpt_manager .save (i_step , model , optimizer , dataloader_state )
469464 if eval_dataloader is not None :
470465 _evaluate (
471466 model ,
@@ -484,13 +479,7 @@ def _train_and_evaluate(
484479 if save_checkpoints_epochs > 0 and i_step > 0 :
485480 if (i_epoch + 1 ) % save_checkpoints_epochs == 0 :
486481 last_ckpt_step = i_step
487- ckpt_dir = os .path .join (model_dir , f"model.ckpt-{ i_step } " )
488- checkpoint_util .save_model (
489- ckpt_dir ,
490- model ,
491- optimizer ,
492- )
493- checkpoint_util .save_dataloader_state (ckpt_dir , dataloader_state )
482+ ckpt_manager .save (i_step , model , optimizer , dataloader_state )
494483 if eval_dataloader is not None :
495484 _evaluate (
496485 model ,
@@ -525,13 +514,7 @@ def _train_and_evaluate(
525514 if train_config .is_profiling :
526515 prof .stop ()
527516 if last_ckpt_step != i_step :
528- ckpt_dir = os .path .join (model_dir , f"model.ckpt-{ i_step } " )
529- checkpoint_util .save_model (
530- ckpt_dir ,
531- model ,
532- optimizer ,
533- )
534- checkpoint_util .save_dataloader_state (ckpt_dir , dataloader_state )
517+ ckpt_manager .save (i_step , model , optimizer , dataloader_state )
535518 if eval_dataloader is not None :
536519 _evaluate (
537520 model ,
@@ -544,6 +527,7 @@ def _train_and_evaluate(
544527 check_all_workers_data_status = check_all_workers_data_status ,
545528 )
546529 model .train ()
530+ ckpt_manager .close ()
547531
548532
549533def train_and_evaluate (
@@ -610,10 +594,17 @@ def train_and_evaluate(
610594 gl_cluster = gl_cluster ,
611595 )
612596
597+ ckpt_manager = checkpoint_util .CheckpointManager (
598+ pipeline_config .model_dir ,
599+ keep_checkpoint_max = train_config .keep_checkpoint_max ,
600+ export_config = pipeline_config .export_config ,
601+ )
602+
613603 # Get Restore Ckpt Path
614604 ckpt_path = None
615605 skip_steps = - 1
616606 if pipeline_config .train_config .fine_tune_checkpoint :
607+ # fine_tune_checkpoint is an external dir, outside the manager's model_dir.
617608 ckpt_path , _ = checkpoint_util .latest_checkpoint (
618609 pipeline_config .train_config .fine_tune_checkpoint
619610 )
@@ -624,9 +615,7 @@ def train_and_evaluate(
624615 )
625616 if os .path .exists (pipeline_config .model_dir ):
626617 # Restore dataloader state if continuing training
627- latest_ckpt_path , skip_steps = checkpoint_util .latest_checkpoint (
628- pipeline_config .model_dir
629- )
618+ latest_ckpt_path , skip_steps = ckpt_manager .latest_checkpoint ()
630619 if latest_ckpt_path :
631620 if continue_train :
632621 ckpt_path = latest_ckpt_path
@@ -641,7 +630,7 @@ def train_and_evaluate(
641630 # Restore dataloader checkpoint state
642631 dataloader_state : Optional [Dict [str , int ]] = None
643632 if ckpt_path and continue_train :
644- dataloader_state = checkpoint_util .restore_dataloader_state (ckpt_path )
633+ dataloader_state = ckpt_manager .restore_dataloader_state (ckpt_path )
645634 if dataloader_state :
646635 train_dataloader .dataset .load_state_dict (dataloader_state )
647636
@@ -778,6 +767,7 @@ def train_and_evaluate(
778767 pipeline_config .model_dir ,
779768 train_config = train_config ,
780769 eval_config = pipeline_config .eval_config ,
770+ ckpt_manager = ckpt_manager ,
781771 skip_steps = skip_steps ,
782772 ckpt_path = ckpt_path ,
783773 check_all_workers_data_status = check_all_workers_data_status ,
@@ -837,11 +827,10 @@ def evaluate(
837827 model , device = device , mixed_precision = train_config .mixed_precision
838828 )
839829
830+ ckpt_manager = checkpoint_util .CheckpointManager (pipeline_config .model_dir )
840831 global_step = None
841832 if not checkpoint_path :
842- checkpoint_path , global_step = checkpoint_util .latest_checkpoint (
843- pipeline_config .model_dir
844- )
833+ checkpoint_path , global_step = ckpt_manager .latest_checkpoint ()
845834 planner = create_planner (
846835 device = device ,
847836 # pyre-ignore [16]
@@ -864,7 +853,7 @@ def evaluate(
864853 )
865854
866855 if checkpoint_path :
867- checkpoint_util . restore_model (checkpoint_path , model )
856+ ckpt_manager . restore (checkpoint_path , model )
868857 else :
869858 raise ValueError ("Eval checkpoint path should be specified." )
870859
@@ -947,17 +936,16 @@ def export(
947936 model = InferWrapper (model )
948937
949938 if not checkpoint_path :
939+ ckpt_manager = checkpoint_util .CheckpointManager (
940+ pipeline_config .model_dir , export_config = pipeline_config .export_config
941+ )
950942 if (
951943 pipeline_config .HasField ("export_config" )
952944 and pipeline_config .export_config .exporter_type == "best"
953945 ):
954- checkpoint_path , _ = checkpoint_util .best_checkpoint (
955- pipeline_config .model_dir , pipeline_config .export_config
956- )
946+ checkpoint_path , _ = ckpt_manager .best_checkpoint ()
957947 else :
958- checkpoint_path , _ = checkpoint_util .latest_checkpoint (
959- pipeline_config .model_dir
960- )
948+ checkpoint_path , _ = ckpt_manager .latest_checkpoint ()
961949
962950 if isinstance (model .model , MatchModel ):
963951 for name , module in model .model .named_children ():
@@ -1386,11 +1374,10 @@ def predict_checkpoint(
13861374 output_cols = output_cols ,
13871375 )
13881376
1377+ ckpt_manager = checkpoint_util .CheckpointManager (pipeline_config .model_dir )
13891378 global_step = None
13901379 if not checkpoint_path :
1391- checkpoint_path , global_step = checkpoint_util .latest_checkpoint (
1392- pipeline_config .model_dir
1393- )
1380+ checkpoint_path , global_step = ckpt_manager .latest_checkpoint ()
13941381 planner = create_planner (
13951382 device = device ,
13961383 # pyre-ignore [16]
@@ -1413,7 +1400,7 @@ def predict_checkpoint(
14131400 model .eval ()
14141401
14151402 if checkpoint_path :
1416- checkpoint_util . restore_model (checkpoint_path , model )
1403+ ckpt_manager . restore (checkpoint_path , model )
14171404 else :
14181405 raise ValueError ("Predict checkpoint path should be specified." )
14191406
0 commit comments