-
Notifications
You must be signed in to change notification settings - Fork 399
Expand file tree
/
Copy pathmegatron_importer.py
More file actions
831 lines (739 loc) · 37.3 KB
/
megatron_importer.py
File metadata and controls
831 lines (739 loc) · 37.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code that export quantized Megatron Core models for deployment."""
import tempfile
from pathlib import Path
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from tqdm import tqdm
from modelopt.torch.utils import import_plugin
from .mcore_common import all_mcore_hf_import_mapping
from .mcore_custom import (
CustomModuleMapping,
ParallelConfig,
dequantize_mxfp4_to_bf16,
get_safetensor,
)
with import_plugin("transformers", verbose=False):
import transformers
has_mcore = False
with import_plugin("megatron"):
from megatron.core.parallel_state import (
get_expert_tensor_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.torch_norm import L2Norm
from megatron.core.transformer.transformer_layer import TransformerLayer
has_mcore = True
class GPTModelImporter:
"""Megatron Core GPTModel HuggingFace Importer.
The Importer is created by `import_mcore_gpt_from_hf` to host attributes
and methods that import a Megatron Core GPTModel from a supported Hugging
Face model.
Args:
model: The Megatron Core GPTModel instance.
pretrained_model_name_or_path: Can be either: the *model id* of a
pretrained model hosted inside a model repo on huggingface.co; or
a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
dtype: The data type that unquantized weights are cast to when imported.
"""
weight_scale_name: str = "weight_scale_inv"
def __init__(
    self,
    model: torch.nn.Module,
    pretrained_model_name_or_path: str,
    workspace_dir: str | None = None,
    dtype=torch.bfloat16,
    dequantize: bool = True,
    trust_remote_code: bool = False,
    verbose: bool = False,
    moe_router_dtype: str | None = None,
):
    """Create a GPTModel importer instance.

    Args:
        model: The Megatron Core model whose weights will be overwritten.
        pretrained_model_name_or_path: A local checkpoint directory, or a
            Hugging Face repo id to download when no such directory exists.
        workspace_dir: Parent directory for the downloaded snapshot. Defaults
            to the system temporary directory.
        dtype: The data type that unquantized weights are cast to.
        dequantize: Forwarded to ``get_safetensor`` on every tensor load.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        verbose: Print a per-layer completion message; also disables tqdm bars.
        moe_router_dtype: ``"fp32"`` or ``"fp64"`` to select the dtype passed to
            the MoE router rule. Any other value leaves it as ``None``.

    Raises:
        KeyError: If the checkpoint's first architecture has no import rules.
    """
    self._hf_config = transformers.AutoConfig.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=trust_remote_code
    )

    # Only the two string spellings below are recognized.
    self.moe_router_dtype = None
    if moe_router_dtype == "fp32":
        self.moe_router_dtype = torch.float32
    elif moe_router_dtype == "fp64":
        self.moe_router_dtype = torch.float64

    pretrained_model_path = Path(pretrained_model_name_or_path)
    if not pretrained_model_path.is_dir():
        if workspace_dir is None:
            workspace_dir = tempfile.gettempdir()
        # Keep this a Path (not a concatenated str) so the attribute has the
        # same type whether the checkpoint was local or downloaded.
        pretrained_model_path = Path(workspace_dir) / pretrained_model_name_or_path
        # Rank 0 downloads; every rank then waits for the files to exist.
        if dist.get_rank() == 0:
            snapshot_download(
                repo_id=pretrained_model_name_or_path,
                local_dir=pretrained_model_path,
            )
        dist.barrier()

    self.arch = self._hf_config.architectures[0]
    self.all_rules = self._populate_rule_book()
    self.rules = self.all_rules[self.arch]
    self.model = model
    self.pretrained_model_path = pretrained_model_path
    self.dtype = dtype
    self.dequantize = dequantize
    self.verbose = verbose
    # Progress bars are shown on rank 0 only, and never in verbose mode.
    self.disable_tqdm = dist.get_rank() > 0 or verbose
def _populate_rule_book(self):
    """Build the rule book for every architecture in the import mapping.

    Returns:
        A dict keyed by architecture name. Each value maps a rule name to
        either a callable ``rule(module, *format_args, **kwargs)`` or a plain
        ``bool`` flag. Mapping entries of any other type are dropped.
    """
    handlers = {
        "name_remapping": self._name_remapping,
        "qkv_merging": self._qkv_merging,
        "gated_mlp_merging": self._gated_mlp_merging,
        "grouped_mlp_merging": self._grouped_mlp_merging,
        "unpack_name_remapping": self._unpack_name_remapping,
        "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss,
    }

    def _as_rule(entry):
        # Resolve everything once so the returned closure only formats and calls.
        handler = handlers[entry.func_name]
        name_template = entry.target_name_or_prefix
        preset_kwargs = entry.func_kwargs

        def _rule(m, *args, **kwargs):
            # Call-site kwargs take precedence over the mapping's preset kwargs.
            return handler(m, name_template.format(*args), **{**preset_kwargs, **kwargs})

        return _rule

    rule_book = {}
    for arch, mappings in all_mcore_hf_import_mapping.items():
        arch_rules = {}
        for rule_name, entry in mappings.items():
            if isinstance(entry, CustomModuleMapping):
                arch_rules[rule_name] = _as_rule(entry)
            elif isinstance(entry, bool):
                arch_rules[rule_name] = entry
        rule_book[arch] = arch_rules
    return rule_book
def _get_safetensor(self, key, parallel_config: ParallelConfig | None = None):
    """Load tensor ``key`` from this importer's checkpoint directory.

    Delegates to ``get_safetensor`` with the importer's ``dequantize`` setting.
    """
    checkpoint_dir = self.pretrained_model_path
    should_dequantize = self.dequantize
    return get_safetensor(checkpoint_dir, key, parallel_config, dequantize=should_dequantize)
def _name_remapping(
    self,
    module,
    prefix,
    mapping=None,
    parallel_config: ParallelConfig | None = None,
    dtype: torch.dtype | None = None,
    is_mtp: bool = False,
):
    """Load one checkpoint entry into ``module`` under a remapped name.

    Args:
        module: A bare ``torch.Tensor`` (copied into in place) or a module
            whose state dict contains ``weight``.
        prefix: Checkpoint key prefix. For a tensor it is the full key.
        mapping: Optional ``{state_dict_key: checkpoint_suffix}`` overrides for
            keys other than the weight. A value of ``None`` keeps the module's
            current value for that key.
        parallel_config: Sharding description forwarded to the tensor loader.
        dtype: Target dtype for unquantized tensors. Defaults to ``self.dtype``.
        is_mtp: Rewrite ``backbone`` (or else ``model``) in ``prefix`` to ``mtp``.

    Raises:
        ValueError: If ``module`` has no ``weight`` entry in its state dict.
    """
    # A None default avoids sharing one mutable dict object across all calls.
    if mapping is None:
        mapping = {}

    if is_mtp:
        if "backbone" in prefix:
            prefix = prefix.replace("backbone", "mtp")
        else:
            prefix = prefix.replace("model", "mtp")

    if dtype is None:
        dtype = self.dtype

    if isinstance(module, torch.Tensor):
        tensor = self._get_safetensor(prefix, parallel_config=parallel_config)
        module.data.copy_(tensor)
        return

    # Snapshot the state dict once instead of rebuilding it for every lookup.
    module_state = module.state_dict()
    weight = module_state.get("weight", None)
    weight_scale = module_state.get("weight_quantizer._scale", None)

    if weight is None:
        raise ValueError(f"{module!s} does not contain weight!")
    tensor = self._get_safetensor(prefix + "weight", parallel_config=parallel_config)

    state_dict = {}
    if weight_scale is not None:
        scale_name = prefix + self.weight_scale_name
        # Only a non-scalar scale is loaded with the sharding config.
        if weight_scale.ndim > 0:
            scale = self._get_safetensor(scale_name, parallel_config=parallel_config)
        else:
            scale = self._get_safetensor(scale_name)
        scale = scale.to(weight_scale.dtype).to(device=weight_scale.device)
        state_dict["weight_quantizer._scale"] = scale
        # Zero-pad a smaller checkpoint tensor up to the module's weight shape.
        if tensor.shape != weight.shape:
            expanded_tensor = torch.zeros(weight.shape, dtype=tensor.dtype)
            expanded_tensor[: tensor.shape[0], : tensor.shape[1]] = tensor
            tensor = expanded_tensor
        # Quantized payloads are reinterpreted bit-for-bit, not value-cast.
        state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device)
    else:
        state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device)

    # Handle the rest of the state_dict.
    for key, val in module_state.items():
        if key in {"weight", "weight_quantizer._scale"}:
            continue
        if "extra_state" in key:
            state_dict[key] = val
            continue

        source_key = mapping.get(key, key)
        # A mapping value of None means "skip this key" (keep existing value).
        # This is used for fused TE modules where layer_norm_weight is loaded
        # separately from a different HF path.
        if source_key is None:
            state_dict[key] = val
            continue

        # For bias tensors in ROW_TP layers, don't use parallel config to avoid
        # sharding since bias should always be replicated, not sharded.
        replicate_bias = (
            key == "bias" and parallel_config is not None and parallel_config.sharding_dim == 1
        )
        tensor = self._get_safetensor(
            prefix + source_key,
            parallel_config=None if replicate_bias else parallel_config,
        )
        state_dict[key] = tensor.to(dtype=dtype).to(device=val.device)

    module.load_state_dict(state_dict)
def _gated_mlp_merging(
    self,
    module,
    prefix,
    gate_proj_name="gate_proj",
    up_proj_name="up_proj",
    parallel_config: ParallelConfig | None = None,
    is_mtp: bool = False,
):
    """Load gate and up projections and concatenate them into one fused weight.

    Args:
        module: Target module whose state dict contains ``weight``.
        prefix: Checkpoint key prefix shared by both projections.
        gate_proj_name: Checkpoint name of the gate projection.
        up_proj_name: Checkpoint name of the up projection.
        parallel_config: Sharding description forwarded to the tensor loader.
        is_mtp: Rewrite ``backbone`` (or else ``model``) in ``prefix`` to ``mtp``.

    Raises:
        ValueError: If ``module`` has no ``weight`` entry in its state dict.
    """
    if is_mtp:
        if "backbone" in prefix:
            prefix = prefix.replace("backbone", "mtp")
        else:
            prefix = prefix.replace("model", "mtp")

    module_state = module.state_dict()
    weight = module_state.get("weight", None)
    weight_scale = module_state.get("weight_quantizer._scale", None)

    if weight is None:
        raise ValueError(f"{module!s} does not contain weight!")

    gate_proj = self._get_safetensor(
        prefix + gate_proj_name + ".weight", parallel_config=parallel_config
    )
    up_proj = self._get_safetensor(
        prefix + up_proj_name + ".weight", parallel_config=parallel_config
    )
    # Gate rows come first, then up rows.
    tensor = torch.cat((gate_proj, up_proj), dim=0)

    state_dict = {}
    if weight_scale is not None:
        gate_scale_name = prefix + gate_proj_name + "." + self.weight_scale_name
        up_scale_name = prefix + up_proj_name + "." + self.weight_scale_name
        if weight_scale.ndim > 0:
            gate_scale = self._get_safetensor(gate_scale_name, parallel_config=parallel_config)
            up_scale = self._get_safetensor(up_scale_name, parallel_config=parallel_config)
            scale = torch.cat((gate_scale, up_scale), dim=0)
        else:
            # NOTE(review): only the gate projection's scale is consulted for a
            # scalar target scale; confirm the up projection shares it.
            scale = self._get_safetensor(gate_scale_name)
            # The target scale is a scalar: collapse a non-scalar source scale
            # to its global maximum. ``Tensor.max(dim=...)`` returns a
            # (values, indices) namedtuple, so chaining ``.max(dim=0)`` on it
            # raised AttributeError; the dim-less ``max()`` returns a tensor.
            if scale.ndim > 0:
                scale = scale.max()
        state_dict["weight_quantizer._scale"] = scale.to(weight_scale.dtype).to(
            device=weight_scale.device
        )
        # Quantized payloads are reinterpreted bit-for-bit, not value-cast.
        state_dict["weight"] = tensor.view(weight.dtype).to(device=weight.device)
    else:
        state_dict["weight"] = tensor.to(self.dtype).to(device=weight.device)

    module.load_state_dict(state_dict)
def _grouped_mlp_merging(
    self,
    module,
    prefix,
    parallel_config: ParallelConfig | None = None,
    is_mtp: bool = False,
    init_expert_id: int = 0,
    num_local_experts: int = 1,
):
    """Load per-expert weights into a grouped-GEMM module.

    Args:
        module: Target module exposing ``num_gemms`` and ``weight<i>`` entries.
        prefix: Checkpoint key template with one placeholder for the expert id.
        parallel_config: Accepted for signature parity with other rules; unused.
        is_mtp: Rewrite ``backbone`` (or else ``model``) in ``prefix`` to ``mtp``.
        init_expert_id: Checkpoint id of the first expert held by this module.
        num_local_experts: Number of consecutive experts to load.

    Raises:
        AssertionError: If ``module.num_gemms`` differs from ``num_local_experts``.
    """
    if is_mtp:
        if "backbone" in prefix:
            prefix = prefix.replace("backbone", "mtp")
        else:
            prefix = prefix.replace("model", "mtp")

    # Start from the module's own state so non-weight entries are preserved.
    state_dict = module.state_dict()

    assert module.num_gemms == num_local_experts, (
        "num_gemms must be equal to num_local_experts in TEGroupedMLP"
    )

    for local_id in range(num_local_experts):
        expert_id = init_expert_id + local_id
        tensor = self._get_safetensor(prefix.format(expert_id) + ".weight")
        # The checkpoint is addressed by global expert id, but the module's
        # entries are numbered locally from 0 (it holds exactly
        # ``num_local_experts`` GEMMs, as asserted above).
        state_dict[f"weight{local_id}"] = tensor

    # TODO handle weight_scale

    module.load_state_dict(state_dict)
def _qkv_merging(
    self,
    module,
    prefix,
    q_proj_name="q_proj",
    k_proj_name="k_proj",
    v_proj_name="v_proj",
    parallel_config: ParallelConfig | None = None,
    is_mtp: bool = False,
):
    """Load separate q/k/v projections and merge them into a fused QKV module.

    The fused tensor is treated as ``qkv_total_dim`` head-sized slots. For each
    query group the slots are ``heads_per_group`` query heads, then one key
    head, then one value head; ``q_slice``/``k_slice``/``v_slice`` index them.

    Args:
        module: Fused QKV module; reads ``module.config`` and its state dict.
        prefix: Checkpoint key prefix shared by the three projections.
        q_proj_name: Checkpoint name of the query projection.
        k_proj_name: Checkpoint name of the key projection.
        v_proj_name: Checkpoint name of the value projection.
        parallel_config: When given, head counts are divided by the tensor
            parallel world size and the config is forwarded to the loader.
        is_mtp: Rewrite ``backbone`` (or else ``model``) in ``prefix`` to ``mtp``.

    Raises:
        ValueError: If ``module`` has no ``weight`` entry in its state dict.
    """
    if is_mtp:
        if "backbone" in prefix:
            prefix = prefix.replace("backbone", "mtp")
        else:
            prefix = prefix.replace("model", "mtp")

    config = module.config
    hidden_size = config.hidden_size
    num_query_groups = config.num_query_groups
    head_num = config.num_attention_heads
    head_size = config.kv_channels

    # With sharding, work in per-rank head counts.
    if parallel_config is not None:
        tp_size = get_tensor_model_parallel_world_size()
        assert head_num % tp_size == 0
        assert num_query_groups % tp_size == 0
        head_num = head_num // tp_size
        num_query_groups = num_query_groups // tp_size

    heads_per_group = head_num // num_query_groups
    qkv_total_dim = head_num + 2 * num_query_groups

    # Each group occupies (heads_per_group + 2) consecutive slots: q..., k, v.
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    state_dict = {}

    module_state_dict = module.state_dict()
    weight = module_state_dict.get("weight", None)
    weight_scale = module_state_dict.get("weight_quantizer._scale", None)

    if weight is None:
        raise ValueError(f"{module!s} does not contain weight!")

    if weight_scale is not None:
        q_scale_name = prefix + q_proj_name + "." + self.weight_scale_name
        k_scale_name = prefix + k_proj_name + "." + self.weight_scale_name
        v_scale_name = prefix + v_proj_name + "." + self.weight_scale_name

        if weight_scale.ndim > 0:
            q_scale = self._get_safetensor(q_scale_name, parallel_config=parallel_config)
            k_scale = self._get_safetensor(k_scale_name, parallel_config=parallel_config)
            v_scale = self._get_safetensor(v_scale_name, parallel_config=parallel_config)
            # Written in place into the tensor obtained from module.state_dict().
            weight_scale[q_slice] = q_scale.to(weight_scale.dtype).to(
                device=weight_scale.device
            )
            weight_scale[k_slice] = k_scale.to(weight_scale.dtype).to(
                device=weight_scale.device
            )
            weight_scale[v_slice] = v_scale.to(weight_scale.dtype).to(
                device=weight_scale.device
            )
        else:
            # Scalar target scale: only the query projection's scale is loaded.
            q_scale = self._get_safetensor(q_scale_name)
            weight_scale = q_scale.to(weight_scale.dtype).to(device=weight_scale.device)
        state_dict["weight_quantizer._scale"] = weight_scale

    q_proj = self._get_safetensor(
        prefix + q_proj_name + ".weight", parallel_config=parallel_config
    )
    k_proj = self._get_safetensor(
        prefix + k_proj_name + ".weight", parallel_config=parallel_config
    )
    v_proj = self._get_safetensor(
        prefix + v_proj_name + ".weight", parallel_config=parallel_config
    )
    # Split the leading dimension into heads so the slot indices apply.
    q_proj = q_proj.reshape(-1, head_size, hidden_size)
    k_proj = k_proj.reshape(-1, head_size, hidden_size)
    v_proj = v_proj.reshape(-1, head_size, hidden_size)
    # Work on a copy of the module's weight, viewed as slots.
    tensor = weight.detach().clone().reshape([qkv_total_dim, head_size, hidden_size])

    if weight_scale is not None:
        # Quantized payloads are reinterpreted bit-for-bit, not value-cast.
        tensor[q_slice] = q_proj.view(dtype=tensor.dtype).to(device=tensor.device)
        tensor[k_slice] = k_proj.view(dtype=tensor.dtype).to(device=tensor.device)
        tensor[v_slice] = v_proj.view(dtype=tensor.dtype).to(device=tensor.device)
    else:
        tensor[q_slice] = q_proj.to(dtype=tensor.dtype).to(device=tensor.device)
        tensor[k_slice] = k_proj.to(dtype=tensor.dtype).to(device=tensor.device)
        tensor[v_slice] = v_proj.to(dtype=tensor.dtype).to(device=tensor.device)

    state_dict["weight"] = tensor.reshape(-1, hidden_size)

    # Handle bias merging
    bias = module_state_dict.get("bias", None)
    if bias is not None:
        q_bias = self._get_safetensor(
            prefix + q_proj_name + ".bias", parallel_config=parallel_config
        )
        k_bias = self._get_safetensor(
            prefix + k_proj_name + ".bias", parallel_config=parallel_config
        )
        v_bias = self._get_safetensor(
            prefix + v_proj_name + ".bias", parallel_config=parallel_config
        )

        # Reshape separate biases to match the head structure
        q_bias = q_bias.reshape(-1, head_size)
        k_bias = k_bias.reshape(-1, head_size)
        v_bias = v_bias.reshape(-1, head_size)

        # Create target bias tensor with the same structure as the fused QKV
        bias_tensor = bias.detach().clone().reshape([qkv_total_dim, head_size])

        # Merge biases using the same slicing logic as weights
        bias_tensor[q_slice] = q_bias.to(dtype=bias_tensor.dtype).to(device=bias_tensor.device)
        bias_tensor[k_slice] = k_bias.to(dtype=bias_tensor.dtype).to(device=bias_tensor.device)
        bias_tensor[v_slice] = v_bias.to(dtype=bias_tensor.dtype).to(device=bias_tensor.device)

        state_dict["bias"] = bias_tensor.reshape(-1)

    # A fused layer-norm weight is carried over unchanged from the module.
    layer_norm_weight = module_state_dict.get("layer_norm_weight", None)
    if layer_norm_weight is not None:
        state_dict["layer_norm_weight"] = layer_norm_weight
        state_dict["_extra_state"] = None  # for TE modules require _extra_state key

    module.load_state_dict(state_dict)
def _unpack_name_remapping(
    self,
    module,
    prefix,
    layer_type: str,
    parallel_config: ParallelConfig | None = None,
    is_mtp: bool = False,  # no-op: necessary for _import_transformer_layer
):
    """Distribute one packed checkpoint tensor across the children of ``module``.

    The tensor stored under ``prefix`` is indexed along its first dimension,
    one slice per child. Each slice has its last two dimensions swapped and is
    loaded as the ``weight`` of the child's ``layer_type`` attribute.

    Raises:
        ValueError: If a child's ``layer_type`` module has no ``weight`` entry.
    """
    packed = self._get_safetensor(prefix, parallel_config=parallel_config)
    skipped_keys = {"weight", "weight_quantizer._scale"}

    for child_idx, child in enumerate(module.children()):
        target = getattr(child, layer_type)
        target_state = target.state_dict()
        current_weight = target_state.get("weight", None)
        child_slice = packed[child_idx]

        if current_weight is None:
            raise ValueError(f"{target!s} does not contain weight!")
        # TODO (yueshen): Handle weight_scale case

        # The checkpoint layout is the transpose of the Mcore layout.
        transposed = child_slice.transpose(-1, -2)
        new_state = {"weight": transposed.to(dtype=self.dtype).to(device=current_weight.device)}

        # Carry over any extra-state entries untouched; drop everything else.
        new_state.update(
            {
                name: value
                for name, value in target_state.items()
                if name not in skipped_keys and "extra_state" in name
            }
        )

        target.load_state_dict(new_state)
def _unpack_name_remapping_gpt_oss(
    self,
    module,
    prefix,
    layer_type: str,
    parallel_config: ParallelConfig | None = None,
    is_mtp: bool = False,  # no-op: necessary for _import_transformer_layer
):
    """Unpack the ``_blocks``/``_scales``/``_bias`` checkpoint entries per child.

    The blocks and scales are dequantized with ``dequantize_mxfp4_to_bf16``,
    then sliced along the first dimension, one slice per child of ``module``.
    For ``layer_type == "linear_fc1"`` the rows of the weight and the entries
    of the bias are reordered so even positions fill the first half and odd
    positions fill the second half.

    Raises:
        ValueError: If a child's ``layer_type`` module has no ``weight`` entry.
    """
    packed_blocks = self._get_safetensor(prefix + "_blocks", parallel_config=parallel_config)
    packed_bias = self._get_safetensor(prefix + "_bias", parallel_config=parallel_config)
    packed_scales = self._get_safetensor(prefix + "_scales", parallel_config=parallel_config)
    unpacked = dequantize_mxfp4_to_bf16(packed_blocks, packed_scales, dtype=self.dtype)

    def _split_even_odd(interleaved):
        # Even positions go to the first half, odd positions to the second.
        half = interleaved.shape[0] // 2
        separated = torch.zeros_like(interleaved)
        separated[:half] = interleaved[::2]
        separated[half:] = interleaved[1::2]
        return separated

    skipped_keys = {"weight", "weight_quantizer._scale"}

    for child_idx, child in enumerate(module.children()):
        target = getattr(child, layer_type)
        target_state = target.state_dict()
        current_weight = target_state.get("weight", None)
        child_weight = unpacked[child_idx]

        if current_weight is None:
            raise ValueError(f"{target!s} does not contain weight!")
        # TODO (yueshen): Handle weight_scale case

        if layer_type == "linear_fc1":
            # The unpack keeps the requirement that the slice is 2-D.
            _rows, _cols = child_weight.shape
            child_weight = _split_even_odd(child_weight)
        new_state = {"weight": child_weight.to(dtype=self.dtype).to(device=current_weight.device)}

        for name, value in target_state.items():
            if name in skipped_keys:
                continue
            if "extra_state" in name:
                new_state[name] = value
            elif "bias" in name:
                child_bias = packed_bias[child_idx]
                if layer_type == "linear_fc1":
                    child_bias = _split_even_odd(child_bias)
                new_state["bias"] = child_bias.to(dtype=self.dtype).to(device=value.device)

        target.load_state_dict(new_state)
def _import_mamba_layer(self, layer, layer_id, layer_pbar):
    """Apply the Mamba import rules to one layer.

    The layer norm is imported through the ``norm`` rule unless it is an
    ``IdentityOp``. In that case, if ``in_proj`` carries ``layer_norm_weight``
    and a ``fused_norm`` rule exists, that rule loads the fused weight.
    """
    layer_pbar.set_description("Importing Mamba layer")

    norm_is_identity = isinstance(layer.norm, IdentityOp)
    if not norm_is_identity:
        self.rules["norm"](layer.norm, layer_id)

    mixer = layer.mixer
    # (rule name, mixer attribute) pairs, applied in this order.
    mixer_rules = (
        ("mixer_norm", "norm"),
        ("A_log", "A_log"),
        ("D", "D"),
        ("dt_bias", "dt_bias"),
        ("conv1d", "conv1d"),
        ("in_proj", "in_proj"),
        ("out_proj", "out_proj"),
    )
    for rule_name, attr_name in mixer_rules:
        self.rules[rule_name](getattr(mixer, attr_name), layer_id)

    # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
    # Load the fused layer_norm_weight from the HF norm path.
    if not norm_is_identity:
        return
    if hasattr(mixer.in_proj, "layer_norm_weight") and "fused_norm" in self.rules:
        self.rules["fused_norm"](mixer.in_proj.layer_norm_weight, layer_id)
def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
    """Apply the import rules for one transformer layer.

    Sub-modules that are ``IdentityOp`` are skipped. Attention is imported as
    MLA when the attention type name contains ``MLASelfAttention``, otherwise
    as GQA/MHA. The MLP is imported as MoE when its type name contains ``MoE``,
    otherwise as a dense MLP.

    Args:
        layer: The transformer layer to populate.
        layer_id: Index passed to each rule to format the checkpoint key.
        layer_pbar: Progress bar whose description is updated along the way.
        is_mtp: Forwarded to every rule as the ``is_mtp`` keyword.

    Raises:
        ValueError: For packed experts with expert tensor parallelism when the
            architecture is ``GptOssForCausalLM``.
    """
    if not isinstance(layer.input_layernorm, IdentityOp):
        self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)

    attention = layer.self_attention
    if not isinstance(attention, IdentityOp):
        if "MLASelfAttention" in str(type(attention)):
            # The presence of linear_q_proj selects the variant without q LoRA.
            if hasattr(attention, "linear_q_proj"):
                layer_pbar.set_description("Importing MLA (without q LoRA)")
                self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp)
            else:
                layer_pbar.set_description("Importing MLA (with q LoRA)")
                self.rules["linear_q_down_proj"](
                    attention.linear_q_down_proj, layer_id, is_mtp=is_mtp
                )
                self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp)
                self.rules["linear_q_up_proj"](
                    attention.linear_q_up_proj, layer_id, is_mtp=is_mtp
                )
            self.rules["linear_kv_down_proj"](
                attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp
            )
            self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp)
            self.rules["linear_kv_up_proj"](
                attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp
            )
            self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp)
        else:
            layer_pbar.set_description("Importing GQA/MHA")
            # q/k layer norms are loaded only when q_layernorm is neither
            # IdentityOp nor L2Norm.
            if attention.q_layernorm is not None and not isinstance(
                attention.q_layernorm, (IdentityOp, L2Norm)
            ):
                self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp)
                self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp)
            self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp)
            self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp)
            if getattr(attention.core_attention, "softmax_offset", None) is not None:
                self.rules["softmax_offset"](
                    attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
                )

        # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
        # Load the fused layer_norm_weight from the HF norm path.
        if (
            isinstance(layer.input_layernorm, IdentityOp)
            and hasattr(attention, "linear_qkv")
            and hasattr(attention.linear_qkv, "layer_norm_weight")
            and "fused_norm" in self.rules
        ):
            self.rules["fused_norm"](
                attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
            )

    if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
        self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)

    if not isinstance(layer.mlp, IdentityOp):
        if "MoE" in str(type(layer.mlp)):
            layer_pbar.set_description(
                f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}"
            )
            # The router rule receives the importer's router dtype override.
            self.rules["router"](
                layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp
            )
            if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None:
                self.rules["fc1_latent_proj"](
                    layer.mlp.fc1_latent_proj, layer_id, is_mtp=is_mtp
                )
            if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None:
                self.rules["fc2_latent_proj"](
                    layer.mlp.fc2_latent_proj, layer_id, is_mtp=is_mtp
                )

            if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None:
                layer_pbar.set_description("Importing MoE shared experts")
                fc1 = layer.mlp.shared_experts.linear_fc1
                fc2 = layer.mlp.shared_experts.linear_fc2
                self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp)
                self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp)

            if not self.rules.get("use_packed_local_experts", False):  # Import local experts
                experts = layer.mlp.experts
                if hasattr(experts, "local_experts"):
                    for local_expert_id, expert in tqdm(
                        enumerate(layer.mlp.experts.local_experts),
                        desc="Importing MoE local experts",
                        leave=False,
                        disable=self.disable_tqdm,
                    ):
                        # Translate the local position into the expert id used
                        # to format the checkpoint key.
                        expert_id = layer.mlp.local_expert_indices[local_expert_id]
                        fc1 = expert.linear_fc1
                        fc2 = expert.linear_fc2
                        self.rules["local_experts.linear_fc1"](
                            fc1, layer_id, expert_id, is_mtp=is_mtp
                        )
                        self.rules["local_experts.linear_fc2"](
                            fc2, layer_id, expert_id, is_mtp=is_mtp
                        )
                else:  # Slice TEGroupedMLP
                    layer_pbar.set_description("Importing MoE grouped local experts")
                    num_local_experts = experts.num_local_experts
                    num_global_experts = experts.config.num_moe_experts
                    assert num_local_experts == num_global_experts, (
                        "num_local_experts must be equal to num_global_experts during MoE import"
                    )
                    # All experts are local (asserted above), so loading starts at 0.
                    init_index = 0

                    self.rules["experts.linear_fc1"](
                        experts.linear_fc1,
                        layer_id,
                        init_expert_id=init_index,
                        num_local_experts=num_local_experts,
                        is_mtp=is_mtp,
                    )
                    self.rules["experts.linear_fc2"](
                        experts.linear_fc2,
                        layer_id,
                        init_expert_id=init_index,
                        num_local_experts=num_local_experts,
                        is_mtp=is_mtp,
                    )

            # We only support either EP or ETP for now
            elif get_expert_tensor_parallel_world_size() > 1:
                # ETP supports for packed MoE
                # ETP is not supported for gpt-oss model
                if self.arch in ["GptOssForCausalLM"]:
                    raise ValueError("ETP is not supported for gpt-oss model")
                self.rules["local_experts.linear_fc1_etp"](
                    layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp
                )
                self.rules["local_experts.linear_fc2_etp"](
                    layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp
                )
            else:
                # EP supports for packed MoE
                self.rules["local_experts.linear_fc1_ep"](
                    layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp
                )
                self.rules["local_experts.linear_fc2_ep"](
                    layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp
                )
        else:
            layer_pbar.set_description("Importing MLP")
            self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
            self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)

            # TE spec: pre_mlp_layernorm is fused into linear_fc1
            # (TELayerNormColumnParallelLinear).
            # Load the fused layer_norm_weight from the HF norm path.
            if (
                isinstance(layer.pre_mlp_layernorm, IdentityOp)
                and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
                and "fused_norm" in self.rules
            ):
                self.rules["fused_norm"](
                    layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
                )
def _import_state_dict(self):
    """Import the whole checkpoint into ``self.model``.

    Sections are processed in order: word embedding, decoder layers, final
    norm, output layer, then multi-token-prediction (MTP) layers. Each section
    is skipped when the model lacks the corresponding attribute.

    Raises:
        ValueError: If a repeated-MTP sub-layer is not a ``TransformerLayer``.
    """
    model = self.model

    layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)

    # Embedding
    if hasattr(model, "embedding"):
        layer_pbar.set_description("Importing word embedding")
        self.rules["word_embeddings"](model.embedding.word_embeddings)

    # Decoder layers
    for layer in layer_pbar:
        layer_pbar.set_description(f"Importing Decoder layer {layer.layer_number}")
        # layer_number is 1-based here; the rules take a 0-based index.
        layer_id = layer.layer_number - 1

        # Layers that are neither Mamba nor Transformer are silently skipped.
        if isinstance(layer, MambaLayer):
            self._import_mamba_layer(layer, layer_id, layer_pbar)
        elif isinstance(layer, TransformerLayer):
            self._import_transformer_layer(layer, layer_id, layer_pbar)

        if self.verbose:
            print(
                "{:3}/{:3} completes importing layer {:3}.".format(
                    dist.get_rank(), dist.get_world_size(), layer_id
                ),
                flush=True,
            )

    # Final layernorm
    if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm:
        self.rules["final_layernorm"](model.decoder.final_layernorm)

    if hasattr(model.decoder, "final_norm") and model.decoder.final_norm:
        self.rules["final_norm"](model.decoder.final_norm)

    # Output layer
    if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
        self.rules["output_layer"](model.output_layer)

    # MTP
    if hasattr(model, "mtp"):
        layer_pbar.set_description("Importing MTP")
        if len(model.mtp.layers) == 1:  # Repeated MTP
            layer_id = 0  # reset layer_id for repeated MTP
            mtp = model.mtp.layers[0]

            self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id)
            self.rules["mtp.enorm"](mtp.enorm, layer_id)
            self.rules["mtp.hnorm"](mtp.hnorm, layer_id)

            mtp_model_layers = mtp.mtp_model_layer.layers
            for mtp_model_layer in mtp_model_layers:
                if isinstance(mtp_model_layer, TransformerLayer):
                    self._import_transformer_layer(
                        mtp_model_layer, layer_id, layer_pbar, is_mtp=True
                    )
                else:
                    raise ValueError(
                        f"Unsupported layer type during MTP import: {type(mtp_model_layer)}.\n"
                        "Only TransformerLayer is supported."
                    )

                layer_id += 1
        else:  # non-repeated MTP
            # MTP is the last layer in DeepSeek V3/R1
            # NOTE(review): layer_id carries over from the decoder loop above and
            # is unbound if that loop ran zero times. It is also not advanced per
            # MTP layer in the loop below - confirm both are intended.
            layer_id += 1
            for mtp in model.mtp.layers:
                self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id)
                self.rules["mtp.enorm"](mtp.enorm, layer_id)
                self.rules["mtp.hnorm"](mtp.hnorm, layer_id)
                self.rules["mtp.input_layernorm"](
                    mtp.decoder.layers[0].input_layernorm, layer_id
                )
                # The presence of linear_q_proj selects the direct q projection;
                # otherwise the down/layernorm/up triple is imported.
                if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"):
                    self.rules["mtp.linear_q_proj"](
                        mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id
                    )
                else:
                    self.rules["mtp.linear_q_down_proj"](
                        mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id
                    )
                    self.rules["mtp.linear_q_layernorm"](
                        mtp.decoder.layers[0].self_attention.q_layernorm, layer_id
                    )
                    self.rules["mtp.linear_q_up_proj"](
                        mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id
                    )
                self.rules["mtp.linear_kv_down_proj"](
                    mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id
                )
                self.rules["mtp.linear_kv_layernorm"](
                    mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id
                )
                self.rules["mtp.linear_kv_up_proj"](
                    mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id
                )
                self.rules["mtp.linear_proj"](
                    mtp.decoder.layers[0].self_attention.linear_proj, layer_id
                )
                self.rules["mtp.pre_mlp_layernorm"](
                    mtp.decoder.layers[0].pre_mlp_layernorm, layer_id
                )
                self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id)
                self.rules["mtp.shared_experts.linear_fc1"](
                    mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id
                )
                self.rules["mtp.shared_experts.linear_fc2"](
                    mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id
                )
                # Here the enumeration index itself is used as the expert id.
                for expert_id, expert in tqdm(
                    enumerate(mtp.decoder.layers[0].mlp.experts.local_experts),
                    desc="Importing MoE local experts",
                    leave=False,
                    disable=self.disable_tqdm,
                ):
                    self.rules["mtp.local_experts.linear_fc1"](
                        expert.linear_fc1, layer_id, expert_id
                    )
                    self.rules["mtp.local_experts.linear_fc2"](
                        expert.linear_fc2, layer_id, expert_id
                    )