# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Export a GPTModel."""
import functools
import os
import sys
import warnings

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

import modelopt.torch.export as mtex
import torch
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.model_builder import modelopt_gpt_mamba_builder
from megatron.training import get_args, get_model
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import unwrap_model
from model_provider import model_provider

warnings.filterwarnings('ignore')


def add_modelopt_export_args(parser):
    """Add additional arguments for ModelOpt hf-like export."""
    group = parser.add_argument_group(title='ModelOpt hf-like export')
    group.add_argument(
        "--export-extra-modules",
        action="store_true",
        help="Export extra modules such as Medusa, EAGLE, or MTP.",
    )
    group.add_argument(
        "--pretrained-model-name",
        type=str,
        help="A pretrained model hosted inside a model repo on huggingface.co.",
    )
    group.add_argument("--export-dir", type=str, help="The target export path.")
    add_modelopt_args(parser)
    return parser


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_modelopt_export_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )
    args = get_args()

    # Meta-device initialization for ParallelLinear only works with cpu initialization.
    # Meta-device initialization is used so that the model can be materialized directly
    # in low precision when ModelOpt real quant is used. Otherwise, the model is first
    # initialized as BF16 in memory, which may cause OOM and defeat the purpose of
    # real quant.
    args.use_cpu_initialization = True
    if not args.init_model_with_meta_device:
        warnings.warn(
            "--init-model-with-meta-device is not set. If you would like to resume the "
            "model in low-bit directly (low-memory initialization and skipping 16-bit), "
            "--init-model-with-meta-device must be set.",
            UserWarning,
        )

    model = get_model(
        functools.partial(model_provider, modelopt_gpt_mamba_builder), wrap_with_ddp=False
    )

    # Materialize the model from the meta device to cpu before loading the checkpoint.
    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.to_empty(device="cpu")
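
    # A minimal sketch of the meta-device pattern relied on above (illustrative only,
    # not part of this script):
    #
    #   with torch.device("meta"):
    #       layer = torch.nn.Linear(4096, 4096)  # declares shapes, allocates no storage
    #   layer.to_empty(device="cpu")             # allocates uninitialized CPU storage
    #
    # The checkpoint load below then fills in the real weights.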

    if args.load is not None:
        _ = load_modelopt_checkpoint(model)

    # Decide whether we are exporting only the extra_modules (e.g. EAGLE3).
    # Only the last pp stage may have extra_modules, hence broadcast from the last rank.
    # broadcast_object_list updates the list in place on receiving ranks, so the list
    # must be kept and the flag read back out of it after the broadcast.
    obj_list = [
        hasattr(unwrapped_model, "eagle_module") or hasattr(unwrapped_model, "medusa_heads")
    ]
    torch.distributed.broadcast_object_list(obj_list, src=torch.distributed.get_world_size() - 1)
    export_extra_modules = obj_list[0]

    mtex.export_mcore_gpt_to_hf(
        unwrapped_model,
        args.pretrained_model_name,
        export_extra_modules=export_extra_modules,
        dtype=torch.bfloat16,
        export_dir=args.export_dir,
    )
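
    # Follow-up sketch (an assumption, not verified against this repo): if the export
    # produces a standard Hugging Face checkpoint layout, it can be reloaded with:
    #
    #   from transformers import AutoModelForCausalLM
    #   model = AutoModelForCausalLM.from_pretrained(args.export_dir)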