Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/usage/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ LR策略可以支持按epoch更新或者按step更新
- num_epochs: 训练的epoch数
- save_checkpoints_steps: 保存模型的步数间隔,保存模型后会做一次评估
- save_checkpoints_epochs: 保存模型的Epoch数间隔,保存模型后会做一次评估,与save_checkpoints_steps不能同时设置
- keep_checkpoint_max: 最多保留的最近checkpoint数量,超出后会异步删除最旧的checkpoint,默认0表示全部保留;当exporter_type为best时,会额外保留指标最优的checkpoint
- fine_tune_checkpoint: 增量训练的checkpoint路径,也可以设置checkpoint目录,将会使用目录下最新的checkpoint
- fine_tune_ckpt_var_map: 需要restore的参数列表文件路径,文件的每一行是{variable_name in current model}\\t{variable name in old model ckpt}
- 需要设置fine_tune_ckpt_var_map的情形:
Expand Down
69 changes: 28 additions & 41 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def _train_and_evaluate(
model_dir: str,
train_config: TrainConfig,
eval_config: EvalConfig,
ckpt_manager: checkpoint_util.CheckpointManager,
skip_steps: int = -1,
ckpt_path: Optional[str] = None,
eval_result_filename: str = TRAIN_EVAL_RESULT_FILENAME,
Expand Down Expand Up @@ -412,7 +413,7 @@ def _train_and_evaluate(
# Restore model and optimizer checkpoint
if i_step == 0 and ckpt_path is not None:
if ignore_restore_optimizer:
checkpoint_util.restore_model(
ckpt_manager.restore(
ckpt_path, model, None, train_config.fine_tune_ckpt_param_map
)
else:
Expand All @@ -421,7 +422,7 @@ def _train_and_evaluate(
peek_batch = next(train_iterator)
pipeline.progress(iter([peek_batch]))
train_iterator = itertools.chain([peek_batch], train_iterator)
checkpoint_util.restore_model(
ckpt_manager.restore(
ckpt_path, model, optimizer, train_config.fine_tune_ckpt_param_map
)

Expand Down Expand Up @@ -459,13 +460,7 @@ def _train_and_evaluate(
if save_checkpoints_steps > 0 and i_step > 0:
if i_step % save_checkpoints_steps == 0:
last_ckpt_step = i_step
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
checkpoint_util.save_model(
ckpt_dir,
model,
optimizer,
)
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
if eval_dataloader is not None:
_evaluate(
model,
Expand All @@ -484,13 +479,7 @@ def _train_and_evaluate(
if save_checkpoints_epochs > 0 and i_step > 0:
if (i_epoch + 1) % save_checkpoints_epochs == 0:
last_ckpt_step = i_step
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
checkpoint_util.save_model(
ckpt_dir,
model,
optimizer,
)
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
if eval_dataloader is not None:
_evaluate(
model,
Expand Down Expand Up @@ -525,13 +514,7 @@ def _train_and_evaluate(
if train_config.is_profiling:
prof.stop()
if last_ckpt_step != i_step:
ckpt_dir = os.path.join(model_dir, f"model.ckpt-{i_step}")
checkpoint_util.save_model(
ckpt_dir,
model,
optimizer,
)
checkpoint_util.save_dataloader_state(ckpt_dir, dataloader_state)
ckpt_manager.save(i_step, model, optimizer, dataloader_state)
if eval_dataloader is not None:
_evaluate(
model,
Expand All @@ -544,6 +527,7 @@ def _train_and_evaluate(
check_all_workers_data_status=check_all_workers_data_status,
)
model.train()
ckpt_manager.close()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

close() only runs if the training loop completes normally. If any step between the first save() (which starts the daemon worker) and here raises, close() is skipped: the final coalesced prune pass is abandoned and the worker thread is leaked (harmless across process exit since it's a daemon, but it breaks the documented "on-disk state settled before export reads model_dir" contract if anything downstream runs in the same process). Consider wrapping the loop body so close() runs in a finally.



def train_and_evaluate(
Expand Down Expand Up @@ -610,10 +594,17 @@ def train_and_evaluate(
gl_cluster=gl_cluster,
)

ckpt_manager = checkpoint_util.CheckpointManager(
pipeline_config.model_dir,
keep_checkpoint_max=train_config.keep_checkpoint_max,
export_config=pipeline_config.export_config,
)

# Get Restore Ckpt Path
ckpt_path = None
skip_steps = -1
if pipeline_config.train_config.fine_tune_checkpoint:
# fine_tune_checkpoint is an external dir, outside the manager's model_dir.
ckpt_path, _ = checkpoint_util.latest_checkpoint(
pipeline_config.train_config.fine_tune_checkpoint
)
Expand All @@ -624,9 +615,7 @@ def train_and_evaluate(
)
if os.path.exists(pipeline_config.model_dir):
# Restore dataloader state if continuing training
latest_ckpt_path, skip_steps = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
latest_ckpt_path, skip_steps = ckpt_manager.latest_checkpoint()
if latest_ckpt_path:
if continue_train:
ckpt_path = latest_ckpt_path
Expand All @@ -641,7 +630,7 @@ def train_and_evaluate(
# Restore dataloader checkpoint state
dataloader_state: Optional[Dict[str, int]] = None
if ckpt_path and continue_train:
dataloader_state = checkpoint_util.restore_dataloader_state(ckpt_path)
dataloader_state = ckpt_manager.restore_dataloader_state(ckpt_path)
if dataloader_state:
train_dataloader.dataset.load_state_dict(dataloader_state)

Expand Down Expand Up @@ -778,6 +767,7 @@ def train_and_evaluate(
pipeline_config.model_dir,
train_config=train_config,
eval_config=pipeline_config.eval_config,
ckpt_manager=ckpt_manager,
skip_steps=skip_steps,
ckpt_path=ckpt_path,
check_all_workers_data_status=check_all_workers_data_status,
Expand Down Expand Up @@ -837,11 +827,10 @@ def evaluate(
model, device=device, mixed_precision=train_config.mixed_precision
)

ckpt_manager = checkpoint_util.CheckpointManager(pipeline_config.model_dir)
global_step = None
if not checkpoint_path:
checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
checkpoint_path, global_step = ckpt_manager.latest_checkpoint()
planner = create_planner(
device=device,
# pyre-ignore [16]
Expand All @@ -864,7 +853,7 @@ def evaluate(
)

if checkpoint_path:
checkpoint_util.restore_model(checkpoint_path, model)
ckpt_manager.restore(checkpoint_path, model)
else:
raise ValueError("Eval checkpoint path should be specified.")

Expand Down Expand Up @@ -947,17 +936,16 @@ def export(
model = InferWrapper(model)

if not checkpoint_path:
ckpt_manager = checkpoint_util.CheckpointManager(
pipeline_config.model_dir, export_config=pipeline_config.export_config
)
if (
pipeline_config.HasField("export_config")
and pipeline_config.export_config.exporter_type == "best"
):
checkpoint_path, _ = checkpoint_util.best_checkpoint(
pipeline_config.model_dir, pipeline_config.export_config
)
checkpoint_path, _ = ckpt_manager.best_checkpoint()
else:
checkpoint_path, _ = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
checkpoint_path, _ = ckpt_manager.latest_checkpoint()

if isinstance(model.model, MatchModel):
for name, module in model.model.named_children():
Expand Down Expand Up @@ -1386,11 +1374,10 @@ def predict_checkpoint(
output_cols=output_cols,
)

ckpt_manager = checkpoint_util.CheckpointManager(pipeline_config.model_dir)
global_step = None
if not checkpoint_path:
checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
checkpoint_path, global_step = ckpt_manager.latest_checkpoint()
planner = create_planner(
device=device,
# pyre-ignore [16]
Expand All @@ -1413,7 +1400,7 @@ def predict_checkpoint(
model.eval()

if checkpoint_path:
checkpoint_util.restore_model(checkpoint_path, model)
ckpt_manager.restore(checkpoint_path, model)
else:
raise ValueError("Predict checkpoint path should be specified.")

Expand Down
2 changes: 2 additions & 0 deletions tzrec/protos/train.proto
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,7 @@ message TrainConfig {
optional uint32 gradient_accumulation_steps = 18;
// dense gradient clipping config
optional GradClipping grad_clipping = 19;
// maximum number of recent checkpoints to keep; 0 keeps all.
optional uint32 keep_checkpoint_max = 20 [default = 0];

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment reads as a hard cap, but when export_config.exporter_type == "best" the best checkpoint is retained in addition to the N recent ones, so the effective count can be N+1. The class docstring and train.md both note this; the proto comment is the only place that omits it.

Suggested change
optional uint32 keep_checkpoint_max = 20 [default = 0];
// max number of recent checkpoints to keep; 0 keeps all. When
// export_config.exporter_type is "best", the best checkpoint is also retained.
optional uint32 keep_checkpoint_max = 20 [default = 0];

// TBD: qcomm config
}
Loading
Loading