Skip to content

Commit f6ec513

Browse files
justadogistakenbaojiangnan
andauthored
fix ckpt dir check (#320)
Co-authored-by: baojiangnan <baojiangnan@kuaishou.com>
1 parent 95cb2ae commit f6ec513

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

scripts/train_eagle3.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,15 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
303303

304304
# Handle base ckpt, config file
305305
draft_model_last_checkpoint = None
306-
if args.ckpt_dir is not None and os.path.isdir(args.ckpt_dir):
307-
draft_model_config = os.path.join(args.ckpt_dir, "config.json")
308-
draft_model_last_checkpoint = args.ckpt_dir
309-
print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}")
310-
else:
311-
raise ValueError(
312-
f"Provided base model dir {args.ckpt_dir} is not a valid directory."
313-
)
306+
if args.ckpt_dir is not None:
307+
if os.path.isdir(args.ckpt_dir):
308+
draft_model_config = os.path.join(args.ckpt_dir, "config.json")
309+
draft_model_last_checkpoint = args.ckpt_dir
310+
print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}")
311+
else:
312+
raise ValueError(
313+
f"Provided base model dir {args.ckpt_dir} is not a valid directory."
314+
)
314315

315316
# detecting last ckpt for draft model
316317
if args.resume and os.path.isdir(args.output_dir):

0 commit comments

Comments
 (0)