Skip to content

Commit 3dd7170

Browse files
Mandy3311 and claude committed
fix: speed up checkpoint resume with skip_first_batches and per-rank optimizer state
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7a85748 commit 3dd7170

1 file changed

Lines changed: 65 additions & 41 deletions

File tree

scripts/train_dflash.py

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
import torch.distributed as dist
16+
from accelerate import skip_first_batches
1617
from accelerate.utils import set_seed
1718
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1819
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
@@ -281,7 +282,6 @@ def build_dataloader(
281282
args.batch_size,
282283
num_workers=args.dataloader_num_workers,
283284
shuffle=False,
284-
process_group=get_dp_group(),
285285
is_vlm=args.is_vlm,
286286
)
287287

@@ -291,10 +291,22 @@ def build_dataloader(
291291
def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
292292
"""Save checkpoint."""
293293
save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
294-
if dist.get_rank() == 0:
294+
rank = dist.get_rank()
295+
296+
if rank == 0:
295297
os.makedirs(save_dir, exist_ok=True)
296298
dist.barrier()
297299

300+
torch.save(
301+
{
302+
"epoch": epoch,
303+
"global_step": step,
304+
"args": args,
305+
**optimizer.state_dict(),
306+
},
307+
os.path.join(save_dir, f"training_state_rank_{rank}.pt"),
308+
)
309+
298310
with FSDP.state_dict_type(dflash_model, StateDictType.FULL_STATE_DICT):
299311
state_dict = dflash_model.state_dict()
300312
draft_state_dict = {
@@ -303,17 +315,7 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
303315
if "draft_model." in k
304316
}
305317

306-
if dist.get_rank() == 0:
307-
torch.save(
308-
{
309-
"epoch": epoch,
310-
"global_step": step,
311-
"args": args,
312-
**optimizer.state_dict(),
313-
},
314-
os.path.join(save_dir, "training_state.pt"),
315-
)
316-
318+
if rank == 0:
317319
draft_model.save_pretrained(save_dir, state_dict=draft_state_dict)
318320

319321
modeling_src = os.path.join(
@@ -377,43 +379,52 @@ def main():
377379
init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
378380
print_with_rank("Initialized distributed")
379381

382+
target_model, draft_model = build_models(args)
383+
380384
draft_model_last_checkpoint = None
381-
ckpt_info = (0, 0)
382-
if args.resume and os.path.isdir(args.output_dir):
383-
draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(args.output_dir)
384-
print(f"Last checkpoint detected: {draft_model_last_checkpoint}")
385+
if args.ckpt_dir is not None:
386+
if os.path.isdir(args.ckpt_dir):
387+
draft_model_last_checkpoint = args.ckpt_dir
388+
print_on_rank0(f"Using checkpoint: {draft_model_last_checkpoint}")
389+
else:
390+
raise ValueError(
391+
f"Provided ckpt dir {args.ckpt_dir} is not a valid directory."
392+
)
385393

386-
# If resuming, load config from checkpoint to ensure consistency
387-
if draft_model_last_checkpoint:
388-
checkpoint_config_path = os.path.join(
389-
draft_model_last_checkpoint, "config.json"
394+
start_epoch = 0
395+
global_step = 0
396+
ckpt_info = None
397+
if args.resume and os.path.isdir(args.output_dir):
398+
draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(
399+
args.output_dir, prefix="epoch_"
390400
)
391-
if os.path.exists(checkpoint_config_path):
392-
print(f"Loading draft config from checkpoint: {checkpoint_config_path}")
393-
args.draft_config_path = checkpoint_config_path
394-
395-
target_model, draft_model = build_models(args)
401+
if ckpt_info:
402+
start_epoch = ckpt_info[0]
403+
global_step = ckpt_info[1]
404+
print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
396405

406+
rank = dist.get_rank()
397407
resume_state = None
408+
398409
if draft_model_last_checkpoint:
399410
loaded_model = DFlashDraftModel.from_pretrained(
400411
draft_model_last_checkpoint, torch_dtype=torch.bfloat16
401412
)
402413
draft_model.load_state_dict(loaded_model.state_dict())
403414
del loaded_model
404-
print("Loaded draft model weights from checkpoint")
415+
print_on_rank0("Loaded draft model weights from checkpoint")
405416

406417
training_state_path = os.path.join(
407-
draft_model_last_checkpoint, "training_state.pt"
418+
draft_model_last_checkpoint, f"training_state_rank_{rank}.pt"
408419
)
420+
409421
if os.path.exists(training_state_path):
410422
resume_state = torch.load(
411423
training_state_path, map_location="cpu", weights_only=False
412424
)
413-
print(
414-
f"Will resume from epoch {resume_state['epoch']}, "
415-
f"step {resume_state['global_step']}"
416-
)
425+
print(f"[Rank {rank}] Found and loading state from {training_state_path}")
426+
else:
427+
print(f"[Rank {rank}] Warning: {training_state_path} not found!")
417428

418429
tokenizer = AutoTokenizer.from_pretrained(
419430
args.target_model_path, trust_remote_code=args.trust_remote_code
@@ -483,9 +494,6 @@ def main():
483494
)
484495
print_with_rank("Initialized FSDP")
485496

486-
start_epoch = ckpt_info[0]
487-
global_step = ckpt_info[1]
488-
489497
optimizer = BF16Optimizer(
490498
draft_model,
491499
lr=args.learning_rate,
@@ -495,7 +503,7 @@ def main():
495503
)
496504

497505
if resume_state is not None:
498-
optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
506+
optimizer.load_state_dict(resume_state)
499507
start_epoch = resume_state["epoch"]
500508
global_step = resume_state["global_step"]
501509
del resume_state
@@ -518,16 +526,31 @@ def main():
518526
train_dataloader.sampler.set_epoch(epoch)
519527
draft_model.train()
520528

529+
steps_to_skip_this_epoch = 0
530+
if epoch == start_epoch and skip_steps > 0:
531+
steps_to_skip_this_epoch = skip_steps
532+
print_on_rank0(
533+
f"Fast-forwarding DataLoader, skipping first {steps_to_skip_this_epoch} batches..."
534+
)
535+
active_dataloader = skip_first_batches(
536+
train_dataloader, steps_to_skip_this_epoch
537+
)
538+
total_batches = len(train_dataloader) - steps_to_skip_this_epoch
539+
else:
540+
active_dataloader = train_dataloader
541+
total_batches = len(train_dataloader)
542+
521543
if dist.get_rank() == 0:
522544
progress_bar = tqdm(
523-
train_dataloader, desc=f"Training Epoch {epoch}", leave=True
545+
active_dataloader,
546+
total=total_batches,
547+
desc=f"Training Epoch {epoch}",
548+
leave=True,
524549
)
525550
else:
526-
progress_bar = train_dataloader
551+
progress_bar = active_dataloader
527552

528-
for step_in_epoch, data in enumerate(progress_bar):
529-
if epoch == start_epoch and step_in_epoch < skip_steps:
530-
continue
553+
for _, data in enumerate(progress_bar):
531554
global_step += 1
532555

533556
input_ids_cpu = data["input_ids"]
@@ -594,6 +617,7 @@ def main():
594617
{
595618
"loss": f"{loss.item():.4f}",
596619
"acc": f"{accuracy.item():.4f}",
620+
"lr": f"{optimizer.get_learning_rate():.2e}",
597621
"iter_time": f"{elapsed:.2f}s",
598622
}
599623
)

0 commit comments

Comments (0)