Skip to content

Commit f7210f7

Browse files
authored
feat: support moe hf chkpt (#133)
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent de9a4f1 commit f7210f7

2 files changed

Lines changed: 58 additions & 52 deletions

File tree

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
# limitations under the License.
1414

1515
# Local
16-
from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
16+
from .checkpoint_utils import (
17+
patch_huggingface_save_and_load_for_dtensors,
18+
recover_safetensors_from_dcp,
19+
)
1720
from .scattermoe_prepare import prepare_scattermoe
1821

1922
# this is a special patch function to disable foreach for

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -457,75 +457,38 @@ def save_sharded_safetensors(
457457
# --------------------------- SCRIPT -------------------------
458458

459459

460-
# have it serve as a conversion script
461-
if __name__ == "__main__":
462-
# Standard
463-
import argparse
464-
465-
parser = argparse.ArgumentParser(
466-
description=(
467-
"Utility for converting ScatterMoE checkpoint back to the "
468-
"original state dict format. "
469-
"The ScatterMoE checkpoint was saved after the pretrained model "
470-
"had been converted by a module swap, hence the state dict will "
471-
"no longer resemble the original. This utility creates"
472-
)
473-
)
474-
475-
parser.add_argument(
476-
"checkpoint_dir",
477-
help="Path to the checkpoint.",
478-
)
479-
480-
parser.add_argument(
481-
"output_dir", help="Path to the location to write the converted checkpoint."
482-
)
483-
484-
parser.add_argument(
485-
"pretrained_model_name_or_path",
486-
help=(
487-
"In order to reconstruct the state dict, we require hints from "
488-
"the original pretrained model checkpoint (from which this "
489-
"checkpoint is obtained)."
490-
),
491-
default=None,
492-
)
493-
494-
args = parser.parse_args()
495-
496-
# search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
497-
# start with FSDP_MODEL_NAME
498-
if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
499-
checkpoint_dir = args.checkpoint_dir
460+
def recover_safetensors_from_dcp(
461+
checkpoint_dir, pretrained_model_name_or_path, output_dir
462+
):
463+
if checkpoint_dir.startswith(FSDP_MODEL_NAME):
500464
loader = get_state_dict_from_dcp_checkpoint
501465
else:
502-
checkpoint_dir = [
466+
fsdp_checkpoint_dirs = [
503467
x
504-
for x in os.listdir(args.checkpoint_dir)
505-
if os.path.isdir(os.path.join(args.checkpoint_dir, x))
468+
for x in os.listdir(checkpoint_dir)
469+
if os.path.isdir(os.path.join(checkpoint_dir, x))
506470
and x.startswith(FSDP_MODEL_NAME)
507471
]
508-
if len(checkpoint_dir) == 1:
509-
checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
472+
if len(fsdp_checkpoint_dirs) == 1:
473+
checkpoint_dir = os.path.join(checkpoint_dir, fsdp_checkpoint_dirs[0])
510474
loader = get_state_dict_from_dcp_checkpoint
511-
elif len(checkpoint_dir) > 1:
475+
elif len(fsdp_checkpoint_dirs) > 1:
512476
raise ValueError(
513-
f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
477+
f"Found > 1 dirs in dcp checkpoint dir {checkpoint_dir} "
514478
f"that starts with {FSDP_MODEL_NAME}. Please specify the exact dir."
515479
)
516480
else:
517481
# then take it as a safetensors checkpoint
518482
# - do not support .bin checkpoints
519-
checkpoint_dir = args.checkpoint_dir
520483
loader = get_state_dict_from_safe_checkpoint
521484

522485
# - pretrained model name
523-
_name_or_path = args.pretrained_model_name_or_path
486+
_name_or_path = pretrained_model_name_or_path
524487

525488
# assume output directory exists, we do not create it
526489
# - copy the config file if exists
527490
config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
528-
target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
491+
target_config_file = os.path.join(output_dir, CONFIG_NAME)
529492
if os.path.exists(config_file):
530493
shutil.copyfile(config_file, target_config_file)
531494

@@ -544,6 +507,46 @@ def save_sharded_safetensors(
544507
# save it as a safetensors file
545508
save_sharded_safetensors(
546509
{k: v.contiguous() for k, v in state_dict.items()},
547-
args.output_dir,
510+
output_dir,
548511
metadata={"format": "pt"},
549512
)
513+
514+
515+
# have it serve as a conversion script
516+
if __name__ == "__main__":
517+
# Standard
518+
import argparse
519+
520+
parser = argparse.ArgumentParser(
521+
description=(
522+
"Utility for converting ScatterMoE checkpoint back to the "
523+
"orginal state dict format. "
524+
"The ScatterMoE checkpoint was saved after the pretrained model "
525+
"had been converted by a module swap, hence the state dict will "
526+
"no longer resemble the original. This utility creates"
527+
)
528+
)
529+
530+
parser.add_argument(
531+
"checkpoint_dir",
532+
help="Path to the checkpoint.",
533+
)
534+
535+
parser.add_argument(
536+
"output_dir", help="Path to the location to write the converted checkpoint."
537+
)
538+
539+
parser.add_argument(
540+
"pretrained_model_name_or_path",
541+
help=(
542+
"In order to reconstruct the state dict, we requre hints from "
543+
"the original pretrained model checkpoint (from which this "
544+
"checkpoint is obtained)."
545+
),
546+
default=None,
547+
)
548+
549+
args = parser.parse_args()
550+
recover_safetensors_from_dcp(
551+
args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
552+
)

0 commit comments

Comments
 (0)