Skip to content

Commit 8f89ee0

Browse files
authored
fix fsdp2 checkpoint load (#1375)
We need to pass the dataloader even if we're not loading it from a checkpoint, since it gets passed through as none otherwise. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 9467487 commit 8f89ee0

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def main(args: DictConfig) -> float | None:
116116
scheduler=scheduler,
117117
ckpt_path=ckpt_path,
118118
dist_config=dist_config,
119-
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
119+
dataloader=train_dataloader,
120120
process_group=device_mesh.get_group("dp"),
121121
)
122122
logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}")

0 commit comments

Comments
 (0)