@@ -457,75 +457,38 @@ def save_sharded_safetensors(
457457# --------------------------- SCRIPT -------------------------
458458
459459
460- # have it serve as a conversion script
461- if __name__ == "__main__" :
462- # Standard
463- import argparse
464-
465- parser = argparse .ArgumentParser (
466- description = (
467- "Utility for converting ScatterMoE checkpoint back to the "
468- "orginal state dict format. "
469- "The ScatterMoE checkpoint was saved after the pretrained model "
470- "had been converted by a module swap, hence the state dict will "
471- "no longer resemble the original. This utility creaes"
472- )
473- )
474-
475- parser .add_argument (
476- "checkpoint_dir" ,
477- help = "Path to the checkpoint." ,
478- )
479-
480- parser .add_argument (
481- "output_dir" , help = "Path to the location to write the converted checkpoint."
482- )
483-
484- parser .add_argument (
485- "pretrained_model_name_or_path" ,
486- help = (
487- "In order to reconstruct the state dict, we requre hints from "
488- "the original pretrained model checkpoint (from which this "
489- "checkpoint is obtained)."
490- ),
491- default = None ,
492- )
493-
494- args = parser .parse_args ()
495-
496- # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
497- # start with FSDP_MODEL_NAME
498- if args .checkpoint_dir .startswith (FSDP_MODEL_NAME ):
499- checkpoint_dir = args .checkpoint_dir
460+ def recover_safetensors_from_dcp (
461+ checkpoint_dir , pretrained_model_name_or_path , output_dir
462+ ):
463+ if checkpoint_dir .startswith (FSDP_MODEL_NAME ):
500464 loader = get_state_dict_from_dcp_checkpoint
501465 else :
502- checkpoint_dir = [
466+ fsdp_checkpoint_dirs = [
503467 x
504- for x in os .listdir (args . checkpoint_dir )
505- if os .path .isdir (os .path .join (args . checkpoint_dir , x ))
468+ for x in os .listdir (checkpoint_dir )
469+ if os .path .isdir (os .path .join (checkpoint_dir , x ))
506470 and x .startswith (FSDP_MODEL_NAME )
507471 ]
508- if len (checkpoint_dir ) == 1 :
509- checkpoint_dir = os .path .join (args . checkpoint_dir , checkpoint_dir [0 ])
472+ if len (fsdp_checkpoint_dirs ) == 1 :
473+ checkpoint_dir = os .path .join (checkpoint_dir , fsdp_checkpoint_dirs [0 ])
510474 loader = get_state_dict_from_dcp_checkpoint
511- elif len (checkpoint_dir ) > 1 :
475+ elif len (fsdp_checkpoint_dirs ) > 1 :
512476 raise ValueError (
513- f"Found > 1 dirs in dcp checkpoint dir { args . checkpoint_dir } "
477+ f"Found > 1 dirs in dcp checkpoint dir { checkpoint_dir } "
514478 f"that starts with { FSDP_MODEL_NAME } . Please spectify the exact dir."
515479 )
516480 else :
517481 # then take it as a safetensors checkpoint
518482 # - do not support .bin checkpoints
519- checkpoint_dir = args .checkpoint_dir
520483 loader = get_state_dict_from_safe_checkpoint
521484
522485 # - pretrained model name
523- _name_or_path = args . pretrained_model_name_or_path
486+ _name_or_path = pretrained_model_name_or_path
524487
525488 # assume output directory exists, we do not create it
526489 # - copy the config file if exists
527490 config_file = os .path .join (checkpoint_dir , CONFIG_NAME )
528- target_config_file = os .path .join (args . output_dir , CONFIG_NAME )
491+ target_config_file = os .path .join (output_dir , CONFIG_NAME )
529492 if os .path .exists (config_file ):
530493 shutil .copyfile (config_file , target_config_file )
531494
@@ -544,6 +507,46 @@ def save_sharded_safetensors(
544507 # save it as a safetensors file
545508 save_sharded_safetensors (
546509 {k : v .contiguous () for k , v in state_dict .items ()},
547- args . output_dir ,
510+ output_dir ,
548511 metadata = {"format" : "pt" },
549512 )
513+
514+
515+ # have it serve as a conversion script
516+ if __name__ == "__main__" :
517+ # Standard
518+ import argparse
519+
520+ parser = argparse .ArgumentParser (
521+ description = (
522+ "Utility for converting ScatterMoE checkpoint back to the "
523+ "orginal state dict format. "
524+ "The ScatterMoE checkpoint was saved after the pretrained model "
525+ "had been converted by a module swap, hence the state dict will "
526+ "no longer resemble the original. This utility creaes"
527+ )
528+ )
529+
530+ parser .add_argument (
531+ "checkpoint_dir" ,
532+ help = "Path to the checkpoint." ,
533+ )
534+
535+ parser .add_argument (
536+ "output_dir" , help = "Path to the location to write the converted checkpoint."
537+ )
538+
539+ parser .add_argument (
540+ "pretrained_model_name_or_path" ,
541+ help = (
542+ "In order to reconstruct the state dict, we requre hints from "
543+ "the original pretrained model checkpoint (from which this "
544+ "checkpoint is obtained)."
545+ ),
546+ default = None ,
547+ )
548+
549+ args = parser .parse_args ()
550+ recover_safetensors_from_dcp (
551+ args .checkpoint_dir , args .pretrained_model_name_or_path , args .output_dir
552+ )
0 commit comments