Commit c9a0dc6

make last step save sharded ckpt as well
1 parent 23a9e39 commit c9a0dc6

1 file changed

Lines changed: 1 addition & 1 deletion

fms_fsdp/utils/train_utils.py

@@ -166,7 +166,7 @@ def train(
 ddp_stats.zero_()
 torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device())
 
-if batch_idx % cfg.checkpoint_interval == 0:
+if batch_idx % cfg.checkpoint_interval == 0 or batch_idx == cfg.num_steps:
     checkpointer.save(
         batch_idx,
         model,
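
The effect of the new condition can be sketched in isolation. This is a minimal illustration, not code from the repository: `should_save` is a hypothetical helper, and the values (10 total steps, an interval of 4, step counting that starts at 1) are assumptions chosen so that the final step is not a multiple of the interval.

```python
def should_save(batch_idx: int, checkpoint_interval: int, num_steps: int) -> bool:
    """Hypothetical helper mirroring the updated condition in the diff."""
    return batch_idx % checkpoint_interval == 0 or batch_idx == num_steps


# Assumed values for illustration only: 10 steps, checkpoint every 4 steps,
# with steps counted from 1.
saved = [step for step in range(1, 11) if should_save(step, 4, 10)]
print(saved)  # [4, 8, 10]
```

Under the old condition the same run would save only at steps 4 and 8, so the state reached at the last step would not be written out as a checkpoint.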
