Skip to content

Commit 8222023

Browse files
committed
Add support for saving intermediate models
1 parent 3f52a87 commit 8222023

3 files changed

Lines changed: 15 additions & 0 deletions

File tree

bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ lr_scheduler_kwargs:
6060
checkpoint:
6161
ckpt_dir: ???
6262
save_final_model: true
63+
save_final_model_with_checkpoint: false
6364
resume_from_checkpoint: true
6465
save_every_n_steps: 1_000
6566
max_checkpoints: 5

bionemo-recipes/recipes/codonfm_native_te/train_ddp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,13 @@ def main(args: DictConfig) -> float | None:
241241
dist_config=dist_config,
242242
max_checkpoints=args.checkpoint.max_checkpoints,
243243
)
244+
if args.checkpoint.save_final_model_with_checkpoint:
245+
save_final_model_ddp(
246+
model=model,
247+
config=config,
248+
save_directory=ckpt_path / f"step_{step}" / "final_model",
249+
dist_config=dist_config,
250+
)
244251

245252
if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0:
246253
model.eval()

bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ def main(args: DictConfig) -> float | None:
273273
dist_config=dist_config,
274274
max_checkpoints=args.checkpoint.max_checkpoints,
275275
)
276+
if args.checkpoint.save_final_model_with_checkpoint:
277+
save_final_model_fsdp2(
278+
model=model,
279+
config=config,
280+
save_directory=ckpt_path / f"step_{step}" / "final_model",
281+
dist_config=dist_config,
282+
)
276283

277284
if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0:
278285
model.eval()

0 commit comments

Comments
 (0)