1313
1414import torch
1515import torch .distributed as dist
16+ from accelerate import skip_first_batches
1617from accelerate .utils import set_seed
1718from torch .distributed .fsdp import FullyShardedDataParallel as FSDP
1819from torch .distributed .fsdp import MixedPrecision , ShardingStrategy , StateDictType
@@ -281,7 +282,6 @@ def build_dataloader(
281282 args .batch_size ,
282283 num_workers = args .dataloader_num_workers ,
283284 shuffle = False ,
284- process_group = get_dp_group (),
285285 is_vlm = args .is_vlm ,
286286 )
287287
@@ -291,10 +291,22 @@ def build_dataloader(
291291def save_checkpoint (args , epoch , step , dflash_model , draft_model , optimizer ):
292292 """Save checkpoint."""
293293 save_dir = os .path .join (args .output_dir , f"epoch_{ epoch } _step_{ step } " )
294- if dist .get_rank () == 0 :
294+ rank = dist .get_rank ()
295+
296+ if rank == 0 :
295297 os .makedirs (save_dir , exist_ok = True )
296298 dist .barrier ()
297299
300+ torch .save (
301+ {
302+ "epoch" : epoch ,
303+ "global_step" : step ,
304+ "args" : args ,
305+ ** optimizer .state_dict (),
306+ },
307+ os .path .join (save_dir , f"training_state_rank_{ rank } .pt" ),
308+ )
309+
298310 with FSDP .state_dict_type (dflash_model , StateDictType .FULL_STATE_DICT ):
299311 state_dict = dflash_model .state_dict ()
300312 draft_state_dict = {
@@ -303,17 +315,7 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
303315 if "draft_model." in k
304316 }
305317
306- if dist .get_rank () == 0 :
307- torch .save (
308- {
309- "epoch" : epoch ,
310- "global_step" : step ,
311- "args" : args ,
312- ** optimizer .state_dict (),
313- },
314- os .path .join (save_dir , "training_state.pt" ),
315- )
316-
318+ if rank == 0 :
317319 draft_model .save_pretrained (save_dir , state_dict = draft_state_dict )
318320
319321 modeling_src = os .path .join (
@@ -377,43 +379,52 @@ def main():
377379 init_distributed (timeout = args .dist_timeout , tp_size = args .tp_size )
378380 print_with_rank ("Initialized distributed" )
379381
382+ target_model , draft_model = build_models (args )
383+
380384 draft_model_last_checkpoint = None
381- ckpt_info = (0 , 0 )
382- if args .resume and os .path .isdir (args .output_dir ):
383- draft_model_last_checkpoint , ckpt_info = get_last_checkpoint (args .output_dir )
384- print (f"Last checkpoint detected: { draft_model_last_checkpoint } " )
385+ if args .ckpt_dir is not None :
386+ if os .path .isdir (args .ckpt_dir ):
387+ draft_model_last_checkpoint = args .ckpt_dir
388+ print_on_rank0 (f"Using checkpoint: { draft_model_last_checkpoint } " )
389+ else :
390+ raise ValueError (
391+ f"Provided ckpt dir { args .ckpt_dir } is not a valid directory."
392+ )
385393
386- # If resuming, load config from checkpoint to ensure consistency
387- if draft_model_last_checkpoint :
388- checkpoint_config_path = os .path .join (
389- draft_model_last_checkpoint , "config.json"
394+ start_epoch = 0
395+ global_step = 0
396+ ckpt_info = None
397+ if args .resume and os .path .isdir (args .output_dir ):
398+ draft_model_last_checkpoint , ckpt_info = get_last_checkpoint (
399+ args .output_dir , prefix = "epoch_"
390400 )
391- if os .path .exists (checkpoint_config_path ):
392- print (f"Loading draft config from checkpoint: { checkpoint_config_path } " )
393- args .draft_config_path = checkpoint_config_path
394-
395- target_model , draft_model = build_models (args )
401+ if ckpt_info :
402+ start_epoch = ckpt_info [0 ]
403+ global_step = ckpt_info [1 ]
404+ print_on_rank0 (f"Last checkpoint detected: { draft_model_last_checkpoint } " )
396405
406+ rank = dist .get_rank ()
397407 resume_state = None
408+
398409 if draft_model_last_checkpoint :
399410 loaded_model = DFlashDraftModel .from_pretrained (
400411 draft_model_last_checkpoint , torch_dtype = torch .bfloat16
401412 )
402413 draft_model .load_state_dict (loaded_model .state_dict ())
403414 del loaded_model
404- print ("Loaded draft model weights from checkpoint" )
415+ print_on_rank0 ("Loaded draft model weights from checkpoint" )
405416
406417 training_state_path = os .path .join (
407- draft_model_last_checkpoint , "training_state .pt"
418+ draft_model_last_checkpoint , f"training_state_rank_ { rank } .pt"
408419 )
420+
409421 if os .path .exists (training_state_path ):
410422 resume_state = torch .load (
411423 training_state_path , map_location = "cpu" , weights_only = False
412424 )
413- print (
414- f"Will resume from epoch { resume_state ['epoch' ]} , "
415- f"step { resume_state ['global_step' ]} "
416- )
425+ print (f"[Rank { rank } ] Found and loading state from { training_state_path } " )
426+ else :
427+ print (f"[Rank { rank } ] Warning: { training_state_path } not found!" )
417428
418429 tokenizer = AutoTokenizer .from_pretrained (
419430 args .target_model_path , trust_remote_code = args .trust_remote_code
@@ -483,9 +494,6 @@ def main():
483494 )
484495 print_with_rank ("Initialized FSDP" )
485496
486- start_epoch = ckpt_info [0 ]
487- global_step = ckpt_info [1 ]
488-
489497 optimizer = BF16Optimizer (
490498 draft_model ,
491499 lr = args .learning_rate ,
@@ -495,7 +503,7 @@ def main():
495503 )
496504
497505 if resume_state is not None :
498- optimizer .scheduler . load_state_dict (resume_state [ "scheduler_state_dict" ] )
506+ optimizer .load_state_dict (resume_state )
499507 start_epoch = resume_state ["epoch" ]
500508 global_step = resume_state ["global_step" ]
501509 del resume_state
@@ -518,16 +526,31 @@ def main():
518526 train_dataloader .sampler .set_epoch (epoch )
519527 draft_model .train ()
520528
529+ steps_to_skip_this_epoch = 0
530+ if epoch == start_epoch and skip_steps > 0 :
531+ steps_to_skip_this_epoch = skip_steps
532+ print_on_rank0 (
533+ f"Fast-forwarding DataLoader, skipping first { steps_to_skip_this_epoch } batches..."
534+ )
535+ active_dataloader = skip_first_batches (
536+ train_dataloader , steps_to_skip_this_epoch
537+ )
538+ total_batches = len (train_dataloader ) - steps_to_skip_this_epoch
539+ else :
540+ active_dataloader = train_dataloader
541+ total_batches = len (train_dataloader )
542+
521543 if dist .get_rank () == 0 :
522544 progress_bar = tqdm (
523- train_dataloader , desc = f"Training Epoch { epoch } " , leave = True
545+ active_dataloader ,
546+ total = total_batches ,
547+ desc = f"Training Epoch { epoch } " ,
548+ leave = True ,
524549 )
525550 else :
526- progress_bar = train_dataloader
551+ progress_bar = active_dataloader
527552
528- for step_in_epoch , data in enumerate (progress_bar ):
529- if epoch == start_epoch and step_in_epoch < skip_steps :
530- continue
553+ for _ , data in enumerate (progress_bar ):
531554 global_step += 1
532555
533556 input_ids_cpu = data ["input_ids" ]
@@ -594,6 +617,7 @@ def main():
594617 {
595618 "loss" : f"{ loss .item ():.4f} " ,
596619 "acc" : f"{ accuracy .item ():.4f} " ,
620+ "lr" : f"{ optimizer .get_learning_rate ():.2e} " ,
597621 "iter_time" : f"{ elapsed :.2f} s" ,
598622 }
599623 )
0 commit comments