-
Notifications
You must be signed in to change notification settings - Fork 381
Expand file tree
/
Copy pathlayer_utils.py
More file actions
executable file
·1944 lines (1642 loc) · 74.7 KB
/
layer_utils.py
File metadata and controls
executable file
·1944 lines (1642 loc) · 74.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for model_config export.
Some of the logic in this file is empirical and needs to be updated whenever exceptions occur.
"""
from warnings import warn
import torch
import torch.nn as nn
try:
from transformers.activations import ACT2FN
except Exception:
warn("Cannot find transformers package. Hugginface modules cannot be exported.")
from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import import_plugin
from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from .hf_config_map import HF_CONFIG_MAP
from .mcore_config_map import MCORE_CONFIG_MAP
from .model_config import (
LAYERNORM_DEFAULT,
LAYERNORM_RMS,
LINEAR_COLUMN,
LINEAR_GROUP,
LINEAR_ROW,
QUANTIZATION_FP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
AttentionConfig,
ConvConfig,
DecoderLayerConfig,
EmbeddingConfig,
ExpertConfig,
LayernormConfig,
LinearActConfig,
LinearConfig,
MedusaHeadConfig,
MLPConfig,
MOEConfig,
QKVConfig,
RecurrentConfig,
RelativeAttentionTableConfig,
RgLruConfig,
)
from .model_config_utils import pad_weights
from .postprocess import view_as_float8_e4m3fn_if_needed, view_as_uint8_if_needed
from .quant_utils import (
get_activation_scaling_factor,
get_kv_cache_bias,
get_kv_cache_dtype,
get_kv_cache_scaling_factor,
get_prequant_scaling_factor,
get_quantization_format,
get_weight_block_size,
get_weight_scaling_factor,
get_weight_scaling_factor_2,
preprocess_linear_fusion,
)
has_mcore = False
with import_plugin("megatron", verbose=False):
from megatron.core.transformer.module import MegatronModule
has_mcore = True
def get_experts_list(module: torch.nn.Module, model_type: str):
    """Returns list of grouped experts by linear name for given module.

    The result holds one inner list per linear-layer name. Each inner list holds
    that linear layer taken from every entry of ``module.experts``, in expert order.

    Raises:
        NotImplementedError: if ``model_type`` matches none of the supported MoE families.
    """
    qwen_moe_variants = (
        "qwenmoeforcausallm",
        "qwen2moeforcausallm",
        "qwen3moeforcausallm",
        "qwen3nextforcausallm",
    )
    if "mixtralforcausallm" in model_type:
        names = ["w1", "w2", "w3"]
    elif any(variant in model_type for variant in qwen_moe_variants):
        names = ["gate_proj", "down_proj", "up_proj"]
    else:
        raise NotImplementedError(f" {model_type} not supported")

    # Common logic for all supported model types
    expert_count = len(module.experts)
    grouped = []
    for name in names:
        grouped.append(
            [_get_expert_attr(module.experts, idx, name) for idx in range(expert_count)]
        )
    return grouped
def get_dtype(model):
    """Returns the default dtype of the model.

    This is the dtype of the first floating-point tensor yielded by
    ``model.parameters()``, or ``None`` when the model has no such parameter.
    """
    floating_dtypes = (
        param.dtype for param in model.parameters() if torch.is_floating_point(param)
    )
    return next(floating_dtypes, None)
def check_model_compatibility(module_list: list[nn.Module]) -> tuple[bool, bool, bool]:
    """Returns whether the list of modules is compatible with the export logic.

    It also reports whether a positional embedding and an embedding layernorm exist.
    The model is assumed to consist of:
      * one or two embedding layers,
      * exactly one ModuleList of transformer decoders,
      * a final layernorm plus an optional embedding layernorm.
    Any other layout is not supported.

    Returns:
        ``(is_compatible, has_second_embedding, has_second_layernorm)``.
    """
    embedding_count = decoder_list_count = layernorm_count = 0
    for candidate in module_list:
        # First match wins: a module counted as an embedding is not also
        # counted as a decoder list or a layernorm.
        if is_embedding(candidate):
            embedding_count += 1
        elif is_decoder_list(candidate):
            decoder_list_count += 1
        elif is_layernorm(candidate):
            layernorm_count += 1

    compatible = (
        1 <= embedding_count <= 2
        and decoder_list_count == 1
        and 1 <= layernorm_count <= 2
    )
    return compatible, embedding_count > 1, layernorm_count > 1
def get_transformer_layers(model: nn.Module) -> list[nn.Module]:
    """Returns the top-level modules of the transformer model to export.

    Layouts are checked in this order:
      * Megatron-core ``GPTModel``: embedding children, decoder children, output layer.
      * A ``glm`` attribute: it is unwrapped first.
      * A ``transformer`` attribute (LM-head model): its children, with inner
        transformers expanded to ``layers`` + ``final_layernorm``, then ``lm_head``.
      * A ``model`` attribute (LLAMA, InternLM2): its children, then ``lm_head`` or ``output``.
      * Otherwise: the direct children.
    """
    if "GPTModel" in type(model).__name__:
        # mcore models
        mcore_layers = list(model.embedding.children()) if hasattr(model, "embedding") else []
        mcore_layers.extend(model.decoder.children())
        if hasattr(model, "output_layer"):
            mcore_layers.append(model.output_layer)
        return mcore_layers

    if hasattr(model, "glm"):
        model = model.glm

    if hasattr(model, "transformer"):
        # This is a LMHead model; lm_head is processed along with the transformer layers.
        collected = []
        for child in model.transformer.children():
            # QwenVL's visual encoder name as 'VisionTransformer' has no `layers`.
            is_inner_transformer = (
                "Transformer" in type(child).__name__
                and hasattr(child, "layers")
                and is_decoder_list(child.layers)
            )
            if is_inner_transformer:
                collected.extend([child.layers, child.final_layernorm])
            else:
                collected.append(child)
        if hasattr(model, "lm_head"):
            collected.append(model.lm_head)
        return collected

    if hasattr(model, "model"):
        collected = list(model.model.children())
        if hasattr(model, "lm_head"):
            # LLAMA
            collected.append(model.lm_head)
        elif hasattr(model, "output"):
            # InternLM2
            collected.append(model.output)
        return collected

    return list(model.children())
def is_linear(module: nn.Module) -> bool:
    """Returns whether the module is a linear layer.

    Detection is by class name: any name containing ``Linear``, ``Conv1D`` or
    ``NormHead`` counts as linear.
    """
    class_name = type(module).__name__
    return "Linear" in class_name or "Conv1D" in class_name or "NormHead" in class_name
def is_conv(module: nn.Module) -> bool:
    """Returns whether the module is a convolutional layer (class name contains ``Conv``)."""
    class_name = type(module).__name__
    return class_name.find("Conv") != -1
def is_embedding(module: nn.Module) -> bool:
    """Returns whether the module is an embedding layer.

    The class name must contain ``Embedding``. Rotary and Phi/Phi-3 image
    embedding classes are excluded even though their names contain it.
    """
    class_name = type(module).__name__
    if "Embedding" not in class_name:
        return False
    excluded_markers = ("Rotary", "PhiImage", "Phi3Image")
    return not any(marker in class_name for marker in excluded_markers)
def build_embedding_config(module: nn.Module, normalization_constant: float = 1) -> EmbeddingConfig:
    """Builds the embedding config from the module.

    Args:
        module: An embedding module, checked with ``is_embedding``. If it exposes a
            ``word_embeddings`` attribute (chatglm), that inner module is used instead.
        normalization_constant: Multiplier applied to the embedding weight.

    Returns:
        An ``EmbeddingConfig`` holding this rank's share of the scaled weight:
          * Modules whose class name contains ``Parallel`` are used without
            padding/chunking (presumably already sharded; TODO confirm).
          * Otherwise the weight is padded for the world size, split into
            ``world_size`` chunks along dim 0, and this rank's chunk is kept.
    """
    assert is_embedding(module)

    world_size = dist.size()
    rank = dist.rank()

    # Special case for chatglm
    if hasattr(module, "word_embeddings"):
        module = module.word_embeddings

    normalized_weight = module.weight * normalization_constant

    if "Parallel" in type(module).__name__:
        local_weight = normalized_weight
    else:
        # Pad and chunk with the same world size queried above, so padding and
        # chunking can never disagree.
        padded_weight = pad_weights(normalized_weight, world_size)
        local_weight = torch.chunk(padded_weight, world_size, dim=0)[rank]

    return EmbeddingConfig(weight=local_weight)
def is_layernorm(module: nn.Module) -> bool:
    """Returns whether the module is a layernorm layer (``LayerNorm`` or ``RMSNorm`` in its class name)."""
    class_name = type(module).__name__
    return "LayerNorm" in class_name or "RMSNorm" in class_name
def build_layernorm_config(module: nn.Module) -> LayernormConfig:
    """Builds the layernorm config from the module.

    - Type: the layernorm type is RMS when the class name contains ``RMS``, and
      the default type otherwise.
    - Weight: for ``LayerNorm1P`` and Gemma/Gemma2/Gemma3 RMSNorm classes, or
      modules with a truthy ``zero_centered_gamma`` attribute, the exported
      weight is ``weight.float() + 1.0``.
    - Bias: exported when present.
    - Eps: copied from the first of ``eps`` / ``variance_epsilon`` that the
      module defines.
    """
    assert is_layernorm(module)

    class_name = type(module).__name__
    layernorm_type = LAYERNORM_RMS if "RMS" in class_name else LAYERNORM_DEFAULT

    plus_one_classes = ("LayerNorm1P", "GemmaRMSNorm", "Gemma2RMSNorm", "Gemma3RMSNorm")
    needs_plus_one = any(marker in class_name for marker in plus_one_classes) or bool(
        getattr(module, "zero_centered_gamma", False)
    )
    # megatron layernorm's weight needs to be updated.
    weight = module.weight.float() + 1.0 if needs_plus_one else module.weight

    config = LayernormConfig(
        weight=weight,
        bias=getattr(module, "bias", None),
        layernorm_type=layernorm_type,
    )

    # TODO: handle the nemo llama eps config.
    eps_attr = next((attr for attr in ("eps", "variance_epsilon") if hasattr(module, attr)), None)
    if eps_attr is not None:
        config.eps = getattr(module, eps_attr)

    return config
def is_decoder_list(module: nn.Module) -> bool:
    """Returns whether the module is a decoder list.

    The check compares the exact type against ``nn.ModuleList``, so subclasses
    of ``nn.ModuleList`` do not match.
    """
    return nn.ModuleList is type(module)
def is_attention(module: nn.Module) -> bool:
    """Returns whether the module is an attention layer (``Attention`` in its class name)."""
    class_name = type(module).__name__
    return class_name.find("Attention") >= 0
def is_mlp(module: nn.Module) -> bool:
    """Returns whether the module is an MLP layer.

    The upper-cased class name must contain ``MLP`` or ``T5DENSE``.
    """
    upper_name = type(module).__name__.upper()
    return "MLP" in upper_name or "T5DENSE" in upper_name
def is_moe(module: nn.Module) -> bool:
    """Returns whether the module is an MOE layer.

    The lower-cased class name is matched in two steps:
      * common patterns: the name ends with ``sparsemoeblock`` or contains ``moelayer``;
      * explicit non-standard names: ``arcticmoe``, ``deepseekmoe``, ``dbrxffn``.
    """
    lowered = type(module).__name__.lower()
    common_pattern = lowered.endswith("sparsemoeblock") or "moelayer" in lowered
    explicit_match = any(marker in lowered for marker in ("arcticmoe", "deepseekmoe", "dbrxffn"))
    return common_pattern or explicit_match
def is_quantlinear(module: nn.Module) -> bool:
    """Returns whether the module is a quantized linear layer.

    The class name must contain ``QuantLinear``, ``QuantCompressedLinear`` or
    ``QuantFP8Linear``. LoRA and ``ds_kernel`` variants (case-insensitive) are excluded.
    """
    class_name = type(module).__name__
    lowered = class_name.lower()
    if "lora" in lowered or "ds_kernel" in lowered:
        return False
    quant_markers = ("QuantLinear", "QuantCompressedLinear", "QuantFP8Linear")
    return any(marker in class_name for marker in quant_markers)
def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
    """Repeat kv heads if tp_size > num_kv_heads.

    Each group of ``head_size`` consecutive leading-dim entries is repeated
    ``tp_size // num_head`` times in an interleaved fashion.
      * 1-D inputs (e.g. biases) stay 1-D.
      * Other inputs come back as 2-D ``(-1, v.size(-1))``.
    The result is contiguous.

    Raises:
        AssertionError: if ``tp_size`` is not a multiple of ``num_head``.
    """
    assert tp_size % num_head == 0
    repeats = tp_size // num_head

    if v.dim() == 1:
        grouped = v.view(-1, head_size, 1)
        return grouped.repeat_interleave(repeats, dim=0).reshape(-1).contiguous()

    columns = v.size(-1)
    grouped = v.view(-1, head_size, columns)
    return grouped.repeat_interleave(repeats, dim=0).view(-1, columns).contiguous()
def build_qkv(
    qkv_modules: list[nn.Module],
    model_metadata_config,
    ext_config: DecoderLayerConfig = None,
    tp_size: int = 1,
) -> QKVConfig:
    """Converts the qkv modules to the config.

    Args:
        qkv_modules: Either one fused QKV linear module, or three separate
            q, k, v linear modules in that order.
        model_metadata_config: Dict-like metadata.
            * ``training_tensor_parallel`` is read.
            * ``head_is_first_dim`` is read for the fused case, and set to True
              when the fused module is a ``ColumnParallelLinear``.
        ext_config: Decoder layer config. Required. ``num_attention_heads`` and
            ``num_kv_heads`` are read; ``attention_head_size`` is filled in if unset.
        tp_size: Target tensor-parallel size. When it exceeds the number of KV
            heads, K/V weights, multi-element weight scales and biases are duplicated.

    Returns:
        A ``QKVConfig`` with column-parallel ``q``, ``k`` and ``v`` linear configs.

    Raises:
        AssertionError: if ``ext_config`` is None.
        NotImplementedError: if the number of QKV modules is not 1 or 3, or if
            KV heads must be duplicated but the K/V scales cannot be duplicated.
    """
    # Validate up front: ext_config is dereferenced immediately below, so a later
    # assert would never be reached when ext_config is None.
    assert ext_config is not None, "ext_config is None"

    config = QKVConfig()
    q_bias = None
    k_bias = None
    v_bias = None

    block_size = get_weight_block_size(qkv_modules[0])
    num_heads = ext_config.num_attention_heads
    training_tp = model_metadata_config["training_tensor_parallel"]

    if len(qkv_modules) == 1:
        # QKV layers combined as a single module, e.g. gpt
        qkv_module = qkv_modules[0]
        num_kv_heads = ext_config.num_kv_heads

        if "ColumnParallelLinear" in type(qkv_module).__name__:
            # For Megatron-core model, num_kv_heads/num_attention_heads is the first dimension of QKV
            model_metadata_config["head_is_first_dim"] = True

        qkv_weight = qkv_module.weight
        if (
            type(qkv_module).__name__ == "Conv1D"
            and not hasattr(qkv_module, "input_quantizer")
            and not hasattr(qkv_module, "output_quantizer")
        ):
            # For unquantized nn.Conv1D, the weights are transposed compared with the nn.Linear
            qkv_weight = qkv_weight.T

        # Handle the case that num_kv_heads/num_attention_heads is the first dimension of QKV.
        # This logic covers MQA and GQA as well.
        keep_channel_order = not model_metadata_config.get("head_is_first_dim", False)

        def _split(fused_tensor):
            # Split a fused QKV tensor (weight, per-channel scale or bias) into q, k, v parts.
            return _split_fused_qkv_weight_and_scaling(
                fused_tensor,
                num_heads,
                num_kv_heads,
                training_tp,
                keep_channel_order,
            )

        q_weight, k_weight, v_weight = _split(qkv_weight)

        qkv_activation_scaling_factor = get_activation_scaling_factor(qkv_module)
        q_activation_scaling_factor = qkv_activation_scaling_factor
        k_activation_scaling_factor = qkv_activation_scaling_factor
        v_activation_scaling_factor = qkv_activation_scaling_factor

        qkv_weight_scaling_factor = get_weight_scaling_factor(qkv_module)
        if qkv_weight_scaling_factor is not None and qkv_weight_scaling_factor.numel() != 1:
            # SQ and AWQ case
            q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor = _split(
                qkv_weight_scaling_factor
            )
        else:
            q_weight_scaling_factor = qkv_weight_scaling_factor
            k_weight_scaling_factor = qkv_weight_scaling_factor
            v_weight_scaling_factor = qkv_weight_scaling_factor

        # bias
        if qkv_module.bias is not None:
            q_bias, k_bias, v_bias = _split(qkv_module.bias)

        q_weight_scaling_factor_2 = k_weight_scaling_factor_2 = v_weight_scaling_factor_2 = (
            get_weight_scaling_factor_2(qkv_module)
        )
        q_prequant_scaling_factor = k_prequant_scaling_factor = v_prequant_scaling_factor = (
            get_prequant_scaling_factor(qkv_module)
        )
        q_quantization = k_quantization = v_quantization = get_quantization_format(qkv_module)

    elif len(qkv_modules) == 3:
        preprocess_linear_fusion(qkv_modules)

        # Separate QKV layers
        q_module, k_module, v_module = qkv_modules

        q_weight = q_module.weight
        q_activation_scaling_factor = get_activation_scaling_factor(q_module)
        q_weight_scaling_factor = get_weight_scaling_factor(q_module)
        q_quantization = get_quantization_format(q_module)

        k_weight = k_module.weight
        k_activation_scaling_factor = get_activation_scaling_factor(k_module)
        k_weight_scaling_factor = get_weight_scaling_factor(k_module)
        k_quantization = get_quantization_format(k_module)

        v_weight = v_module.weight
        v_activation_scaling_factor = get_activation_scaling_factor(v_module)
        v_weight_scaling_factor = get_weight_scaling_factor(v_module)
        v_quantization = get_quantization_format(v_module)

        q_weight_scaling_factor_2 = get_weight_scaling_factor_2(q_module)
        k_weight_scaling_factor_2 = get_weight_scaling_factor_2(k_module)
        v_weight_scaling_factor_2 = get_weight_scaling_factor_2(v_module)

        q_prequant_scaling_factor = get_prequant_scaling_factor(q_module)
        k_prequant_scaling_factor = get_prequant_scaling_factor(k_module)
        v_prequant_scaling_factor = get_prequant_scaling_factor(v_module)

        if hasattr(q_module, "bias"):
            q_bias = q_module.bias
        if hasattr(k_module, "bias"):
            k_bias = k_module.bias
        if hasattr(v_module, "bias"):
            v_bias = v_module.bias

    else:
        raise NotImplementedError(f"QKV modules format {qkv_modules} not supported")

    # derive num_kv_heads
    head_size = q_weight.size(0) // num_heads
    num_kv_heads = ext_config.num_kv_heads or k_weight.size(0) // head_size

    if tp_size > num_kv_heads:
        if any(
            t is not None and t.numel() > 1
            for t in [
                k_activation_scaling_factor,
                k_weight_scaling_factor_2,
                v_activation_scaling_factor,
                v_weight_scaling_factor_2,
            ]
        ):
            # TODO(oargov): handle cases with biases / scales
            raise NotImplementedError(
                "Duplicating KV heads for KV with non-scalar scales and/or biases is not supported"
            )

        # duplicate kv heads as needed
        k_weight = dup_kv_weight(k_weight, head_size, num_kv_heads, tp_size)
        v_weight = dup_kv_weight(v_weight, head_size, num_kv_heads, tp_size)

        if k_weight_scaling_factor is not None and k_weight_scaling_factor.numel() > 1:
            if len(k_weight_scaling_factor.shape) == 1:
                raise NotImplementedError(
                    "Duplicating KV heads per-channel scales is not supported"
                )
            k_weight_scaling_factor = dup_kv_weight(
                k_weight_scaling_factor, head_size, num_kv_heads, tp_size
            )

        if v_weight_scaling_factor is not None and v_weight_scaling_factor.numel() > 1:
            if len(v_weight_scaling_factor.shape) == 1:
                raise NotImplementedError(
                    "Duplicating KV heads per-channel scales is not supported"
                )
            v_weight_scaling_factor = dup_kv_weight(
                v_weight_scaling_factor, head_size, num_kv_heads, tp_size
            )

        if k_bias is not None:
            k_bias = dup_kv_weight(k_bias, head_size, num_kv_heads, tp_size)
        if v_bias is not None:
            v_bias = dup_kv_weight(v_bias, head_size, num_kv_heads, tp_size)

    def _column_linear(weight, bias, act_sf, weight_sf, weight_sf_2, prequant_sf, quantization):
        # Assemble one column-parallel LinearConfig for q, k or v.
        linear = LinearConfig(linear_type=LINEAR_COLUMN)
        linear.weight = weight
        linear.bias = bias
        linear.activation_scaling_factor = act_sf
        linear.weights_scaling_factor = weight_sf
        linear.weights_scaling_factor_2 = weight_sf_2
        linear.prequant_scaling_factor = prequant_sf
        linear.awq_block_size = block_size
        linear.quantization = quantization
        return linear

    config.q = _column_linear(
        q_weight,
        q_bias,
        q_activation_scaling_factor,
        q_weight_scaling_factor,
        q_weight_scaling_factor_2,
        q_prequant_scaling_factor,
        q_quantization,
    )
    config.k = _column_linear(
        k_weight,
        k_bias,
        k_activation_scaling_factor,
        k_weight_scaling_factor,
        k_weight_scaling_factor_2,
        k_prequant_scaling_factor,
        k_quantization,
    )
    config.v = _column_linear(
        v_weight,
        v_bias,
        v_activation_scaling_factor,
        v_weight_scaling_factor,
        v_weight_scaling_factor_2,
        v_prequant_scaling_factor,
        v_quantization,
    )

    if not ext_config.attention_head_size:
        ext_config.attention_head_size = config.q.weight.shape[0] * training_tp // num_heads

    return config
def build_linear_config(module: nn.Module, linear_type: str) -> LinearConfig:
    """Builds the linear config for the module.

    Weight handling:
      * ``NormHead`` weights are L2-normalized.
      * Unquantized HF ``Conv1D`` weights are transposed to the linear layout.

    Args:
        module: A linear-like module. Every non-Megatron module must satisfy ``is_linear``.
        linear_type: Linear type stored on the config (e.g. column or row).

    Returns:
        A ``LinearConfig`` with the weight, the bias if present, and the module's
        scaling factors, AWQ block size and quantization format.
    """
    # Check only for HF model, not Mcore model. The check must also run when
    # Megatron is not installed (has_mcore is False); `has_mcore` short-circuits
    # before MegatronModule is referenced.
    is_mcore_module = has_mcore and isinstance(module, MegatronModule)
    if not is_mcore_module:
        assert is_linear(module)

    torch_weight = module.weight
    module_type_name = type(module).__name__
    if "NormHead" in module_type_name:
        torch_weight = torch.nn.functional.normalize(torch_weight)
    elif "Conv1D" in module_type_name and not (
        hasattr(module, "input_quantizer") or hasattr(module, "output_quantizer")
    ):
        # Transpose Conv1D weights to linear unless it has been transposed by the quantization.
        torch_weight = torch_weight.T

    config = LinearConfig(linear_type=linear_type)
    config.weight = torch_weight

    if hasattr(module, "bias") and module.bias is not None:
        config.bias = module.bias

    config.activation_scaling_factor = get_activation_scaling_factor(module)
    config.weights_scaling_factor = get_weight_scaling_factor(module)
    config.weights_scaling_factor_2 = get_weight_scaling_factor_2(module)
    config.prequant_scaling_factor = get_prequant_scaling_factor(module)
    config.awq_block_size = get_weight_block_size(module)
    config.quantization = get_quantization_format(module)
    return config
def build_fused_linear_config(modules: list[nn.Module], linear_type: str) -> LinearConfig:
    """Returns a fused linear config from a list of modules.

    The modules are fused along dim 0, in list order. The following are concatenated:
      * weights,
      * multi-element weight scaling factors,
      * biases.
    All other fields come from ``modules[0]`` via ``build_linear_config``, after
    ``preprocess_linear_fusion(modules)`` has run.

    Raises:
        AssertionError: if ``linear_type`` is not ``LINEAR_COLUMN``.
    """
    assert linear_type == LINEAR_COLUMN, "Only support column fuse"
    preprocess_linear_fusion(modules)

    config = build_linear_config(modules[0], linear_type=linear_type)
    config.weight = torch.cat([module.weight for module in modules], dim=0)

    if config.weights_scaling_factor is not None and config.weights_scaling_factor.numel() != 1:
        config.weights_scaling_factor = torch.cat(
            [get_weight_scaling_factor(module) for module in modules], dim=0
        )

    # build_linear_config only captured modules[0]'s bias. Fuse biases so they
    # match the concatenated weight; modules without a bias contribute zeros.
    biases = [getattr(module, "bias", None) for module in modules]
    present = [bias for bias in biases if bias is not None]
    if present:
        reference = present[0]
        config.bias = torch.cat(
            [
                bias
                if bias is not None
                else torch.zeros(
                    module.weight.size(0), dtype=reference.dtype, device=reference.device
                )
                for bias, module in zip(biases, modules)
            ],
            dim=0,
        )

    return config
def build_attention_config(
    module: nn.Module,
    model_metadata_config,
    ext_config: DecoderLayerConfig = None,
    tp_size: int = 1,
) -> AttentionConfig:
    """Builds the attention config from the module.

    Child linear layers are classified by name:
      * fused QKV names (per ``_is_qkv``);
      * otherwise names containing ``q``, ``k`` or ``v``, checked in that order;
      * anything else is the dense output layer.
    Layernorm children whose lower-cased name contains ``q`` or ``k`` become the
    q/k layernorms. For T5 (``model_metadata_config["model_type"] == "t5"``), any
    other non-``TensorQuantizer`` child becomes the relative attention table.

    Args:
        module: The attention module (checked with ``is_attention``).
        model_metadata_config: Dict-like metadata, forwarded to ``build_qkv``.
        ext_config: Decoder layer config, forwarded to ``build_qkv``.
        tp_size: Target tensor-parallel size, forwarded to ``build_qkv``.

    Returns:
        The populated ``AttentionConfig``, including KV-cache scales/biases/dtype.

    Raises:
        AssertionError: if no fused QKV layer exists and any of q/k/v is missing.
        NotImplementedError: for a layernorm child that is neither q nor k.
    """
    assert is_attention(module)

    config = AttentionConfig()
    if hasattr(module, "rotary_dim"):
        config.rotary_dim = module.rotary_dim
    if hasattr(module, "clip_qkv"):
        config.clip_qkv = module.clip_qkv

    qkv_modules = []
    q = None
    k = None
    v = None

    for name, layer in module.named_children():
        if is_linear(layer):
            if _is_qkv(name):
                qkv_modules.append(layer)
            elif "q" in name:
                q = layer
            elif "k" in name:
                k = layer
            elif "v" in name:
                v = layer
            else:
                # The dense layer
                config.dense = build_linear_config(layer, LINEAR_ROW)
        elif is_layernorm(layer):
            if "q" in name.lower():
                config.q_layernorm = build_layernorm_config(layer)
            elif "k" in name.lower():
                config.k_layernorm = build_layernorm_config(layer)
            else:
                raise NotImplementedError(f"{name}: {layer} not recognized")
        elif (
            "model_type" in model_metadata_config
            and model_metadata_config["model_type"] == "t5"
            and not isinstance(layer, TensorQuantizer)
        ):
            config.rel_attn_table = RelativeAttentionTableConfig(weight=layer.weight.T)

    if not qkv_modules:
        assert q
        assert k
        assert v
        qkv_modules = [q, k, v]

        # Only meaningful for separate q/k/v layers, where `q` is guaranteed non-None.
        for layer in qkv_modules:
            # Add the missing zero bias for Whisper model for export purpose.
            # A linear bias has out_features = weight.size(0) entries; match q's bias dtype.
            if layer.bias is None and q.bias is not None:
                layer.bias = torch.nn.Parameter(
                    torch.zeros(
                        layer.weight.size(0),
                        dtype=q.bias.dtype,
                        device=layer.weight.device,
                    ),
                    requires_grad=True,
                )
                print("Add missing zero bias for qkv modules for export purpose")

    config.qkv = build_qkv(qkv_modules, model_metadata_config, ext_config, tp_size=tp_size)

    config.k_cache_scaling_factor, config.v_cache_scaling_factor = get_kv_cache_scaling_factor(
        module
    )
    config.k_cache_bias, config.v_cache_bias = get_kv_cache_bias(module)

    if config.k_cache_scaling_factor is not None:
        assert config.v_cache_scaling_factor is not None
        config.kv_cache_dtype = get_kv_cache_dtype(module)

    return config
def _is_qkv(name) -> bool:
return all(k in name for k in ["q", "k", "v"]) or "W_pack" in name or "c_attn" in name
def _get_hidden_act(act_func) -> str:
"""Returns the name of the hidden activation function based on ACT2FN."""
if isinstance(act_func, str):
return act_func
# Falcon activation, "nn.GELU" is equivalent to "gelu" in ACT2FN
if isinstance(act_func, nn.GELU):
return "gelu"
if hasattr(act_func, "func") and act_func.func == nn.functional.gelu:
return "gelu"
for name, func in ACT2FN.items():
# TRT LLM uses "squared-relu" activation keyword.
if name == "relu2":
name = "squared-relu"
if isinstance(func, tuple):
if isinstance(act_func, func[0]):
return name
elif isinstance(act_func, func):
return name
return act_func.__name__
def build_mlp_config(
    module: nn.Module,
    decoder_type,
    hidden_act: str | None = None,
    merge_gate_fc: bool = False,
) -> MLPConfig:
    """Builds the MLP config for the module.

    Child linear layers are classified by exact child name into fc, gate and
    proj layers. The keyword sets are adjusted for some model families. A fused
    fc layer may be swapped (chatglm/phi3) or split into fc and gate halves.
    The activation name is then resolved from ``hidden_act``, ``decoder_type``
    or the module's own activation attributes.

    Args:
        module: An MLP module (checked with ``is_mlp``).
        decoder_type: Decoder family string, e.g. "gpt", "mpt", "chatglm", "phi3", "qwen".
        hidden_act: Activation name. If None it is inferred from the module.
            It is overridden to "gelu" for "bloom"/"glm" and "swiglu" for "phi3".
        merge_gate_fc: Stored on the config. When set, or when the module's
            quantization format is FP8/NVFP4, fc and gate go through
            ``preprocess_linear_fusion``.

    Returns:
        The populated ``MLPConfig``.

    Raises:
        AssertionError: if no proj or fc layer is found.
        NotImplementedError: if the activation cannot be determined.
    """
    assert is_mlp(module)

    config = MLPConfig(merge_gate_fc=merge_gate_fc)

    def _split_gate_from_fc(decoder_type, module, fc_name, fc_layer):
        # True when fc_layer is treated as fc and gate fused together and must be split.
        # This happens for:
        #   * a ColumnParallelLinear whose module.config enables gated_linear_unit;
        #   * a "gpt" decoder whose fc name contains "dense_h_to_4h" but not "dense_h_to_4h_2".
        if (
            "ColumnParallelLinear" in type(fc_layer).__name__
            and hasattr(module.config, "gated_linear_unit")
            and module.config.gated_linear_unit
        ):
            return True

        if decoder_type != "gpt":
            return False

        return bool("dense_h_to_4h" in fc_name and "dense_h_to_4h_2" not in fc_name)

    # TODO: We may want to refactor these keywords/model mapping
    fc_keywords = {
        "c_fc",  # gpt2
        "fc_in",  # gptj
        "gate_proj",  # llama, baichuan, recurrentgemma, deepseek
        "dense_h_to_4h",  # falcon, chatglm, bloom
        "linear_fc1",
        "w2",  # qwen
        "fc1",  # phi, gemma, whisper
        "gate_up_proj",  # phi
        "wi_0",  # t5
        "wi",  # t5
        "c_fc_0",  # exaone
    }
    proj_keywords = {
        "c_proj",  # gpt2, qwen, exaone
        "fc_out",  # gptj
        "dense_4h_to_h",  # falcon, chatglm, bloom
        "4h_to_h",
        "down_proj",  # llama, baichuan, mpt, phi, recurrentgemma, nemotron, deepseek
        "linear_fc2",
        "proj",
        "fc2",  # phi, gemma, whisper
        "wo",  # t5
    }
    gate_keywords = {
        "up_proj",  # llama, baichuan, recurrentgemma, deepseek
        "dense_h_to_4h_2",
        "w1",  # qwen
        "wi_1",  # t5
        "c_fc_1",  # exaone
    }

    # Arctic (llama-based MoE, decoder_type is "llama") has MLP keyword conflicts with Qwen
    # Arctic's residual MLP use w1 for fc, w2 for proj, w3 for gate
    if type(module).__name__ in ["ArcticMLP", "InternLM2MLP"]:
        fc_keywords.discard("w2")
        gate_keywords.discard("w1")
        fc_keywords.add("w1")
        proj_keywords.add("w2")
        gate_keywords.add("w3")

    # For these families "up_proj" is the fc layer rather than the gate.
    if decoder_type == "mpt":
        fc_keywords.add("up_proj")
        gate_keywords.discard("up_proj")

    if type(module).__name__ in [
        "TLGv4MLP",
        "Phi3SmallMLP",
        "NemotronMLP",
    ]:  # for TLGv4ForCausalLM
        fc_keywords.add("up_proj")
        gate_keywords.discard("up_proj")

    fc_linear: nn.Module = None
    gate_linear: nn.Module = None
    proj_linear: nn.Module = None
    # Classify linear children by exact name. fc keywords are checked first,
    # then gate, then proj; a later match of the same role overwrites an earlier one.
    for name, layer in module.named_children():
        if is_linear(layer):
            if any(keyword == name for keyword in fc_keywords):
                fc_linear = layer
            elif any(keyword == name for keyword in gate_keywords):
                gate_linear = layer
            elif any(keyword == name for keyword in proj_keywords):
                proj_linear = layer

    # TensorRT-LLM may choose to merge gate and fc during engine building.
    if (
        gate_linear is not None
        and fc_linear is not None
        and (
            merge_gate_fc
            or get_quantization_format(module) in [QUANTIZATION_FP8, QUANTIZATION_NVFP4]
        )
    ):
        preprocess_linear_fusion([fc_linear, gate_linear])

    if fc_linear is not None:
        # Resolve the weight quantizer; for a SequentialQuantizer the first stage is used.
        weight_quantizer = None
        if hasattr(fc_linear, "weight_quantizer"):
            weight_quantizer = fc_linear.weight_quantizer
            if isinstance(weight_quantizer, SequentialQuantizer):
                weight_quantizer = weight_quantizer[0]

        # swap fused fc and gate
        # The two halves of fc_linear.weight along dim 0 are swapped in place (.data is
        # mutated). A multi-element amax is swapped the same way to stay aligned.
        if decoder_type in ["chatglm", "phi3"]:
            weights = torch.chunk(fc_linear.weight, 2, dim=0)
            weights = (weights[1], weights[0])
            fc_linear.weight.data = torch.cat(weights, dim=0)

            if (
                weight_quantizer is not None
                and weight_quantizer.is_enabled
                and weight_quantizer.amax.numel() != 1
            ):
                amax_chunks = torch.chunk(weight_quantizer.amax, 2, dim=0)
                weight_quantizer.amax = torch.cat([amax_chunks[1], amax_chunks[0]], dim=0)

        # NOTE(review): `name` is the loop variable left over from the named_children()
        # loop above. It is the name of the LAST child module, not necessarily
        # fc_linear's name. The "dense_h_to_4h" check in _split_gate_from_fc therefore
        # depends on child order. Confirm whether fc_linear's own name was intended
        # before changing it, because non-gated "gpt" MLPs would then be split.
        split_gate = _split_gate_from_fc(decoder_type, module, name, fc_linear)

        if split_gate:
            # We have to split the gate from the fc
            # First half of the weight along dim 0 -> fc, second half -> gate.
            weights = torch.chunk(fc_linear.weight, 2, dim=0)
            weight_scaling_factor = get_weight_scaling_factor(fc_linear)
            weight_scaling_factors = None

            if weight_scaling_factor is not None and weight_scaling_factor.numel() != 1:
                # for Int8 SQ case, we split the weight scaling factor into two parts.
                weight_scaling_factors = torch.chunk(weight_scaling_factor, 2, dim=0)

            # Both configs start from the fused layer; the weights (and per-channel
            # scales, if any) are then replaced by the matching halves.
            config.fc = build_linear_config(fc_linear, LINEAR_COLUMN)
            config.gate = build_linear_config(fc_linear, LINEAR_COLUMN)
            config.fc.weight = weights[0]
            config.gate.weight = weights[1]
            if weight_scaling_factors is not None:
                config.fc.weights_scaling_factor = weight_scaling_factors[0]
                config.gate.weights_scaling_factor = weight_scaling_factors[1]
        else:
            config.fc = build_linear_config(fc_linear, LINEAR_COLUMN)

    if proj_linear is not None:
        config.proj = build_linear_config(proj_linear, LINEAR_ROW)

    # A separately found gate layer overrides any gate produced by splitting fc above.
    if gate_linear is not None:
        config.gate = build_linear_config(gate_linear, LINEAR_COLUMN)

    assert config.proj is not None and config.fc is not None, "proj or fc can not be found"

    # Override hidden_act based on decoder_type
    if decoder_type in ["bloom", "glm"]:
        hidden_act = "gelu"

    if decoder_type == "phi3":
        hidden_act = "swiglu"

    if hidden_act is None:
        if hasattr(module, "activation"):
            hidden_act = module.activation
        elif hasattr(module, "activation_func"):
            # MCore activation_func can be swiglu (gated silu) or squared_relu.
            hidden_act = module.activation_func.__name__.replace("_", "-")
            if hidden_act in ["glu", "silu"]:
                hidden_act = "swiglu" if decoder_type == "gpt" else "silu"
        else:
            # Use the first of these attributes the module defines; the name is truncated
            # at the first underscore.
            for act in ["act", "act_fn", "activation_fn"]:
                if hasattr(module, act):
                    hidden_act = _get_hidden_act(getattr(module, act)).split("_")[0]
                    break

    if hidden_act is None and decoder_type == "qwen":
        # for v1 qwen versions, activation is not explicitly defined as part of the layer's implementation
        hidden_act = "silu"

    if hidden_act is None:
        raise NotImplementedError(f"{module} not supported.")

    config.hidden_act = hidden_act
    return config
def _get_expert_attr(experts: nn.Module, export_id: int, linear_name: str):
# Generic expert attribute accessor.
# Works for most MoE models that store experts as a list/ModuleList where
# each expert has linear layers as direct attributes:
# experts[0].w1, experts[0].w2, experts[0].w3 (Mixtral)
# experts[0].gate_proj, experts[0].down_proj, experts[0].up_proj (Qwen)
# experts[0].linear_fc1, experts[0].linear_fc2 (Llama MCore)
return getattr(experts[export_id], linear_name)
def _get_dbrx_expert(experts: nn.Module, export_id: int, linear_name: str):
# DBRX experts layout is:
# experts:
# w1[0]
# w1[1]
# ...
# w2[0]
# w2[1]
# ...
# v1[0]
# v1[1]
# ...
return getattr(experts, linear_name)[export_id]
def _build_stacked_linear(experts: nn.Module, module_name, linear_type, num_experts, expert_getter):
    """Build a LinearConfig whose tensors are stacked across ``num_experts`` experts.

    ``expert_getter(experts, i, module_name)`` returns expert ``i``'s linear layer.
      * Weights and each available scaling factor are stacked along a new
        leading dimension.
      * The stacked activation scaling factor is reduced to its maximum, shape ``[1]``.
      * The quantization format is taken from ``experts`` as a whole.

    Raises:
        ValueError: if the first expert's linear layer has a bias.
    """

    def linear_of(index):
        return expert_getter(experts, index, module_name)

    config = LinearConfig(linear_type=linear_type)

    # weights
    config.weight = torch.stack([linear_of(i).weight for i in range(num_experts)])

    # bias
    leading_linear = linear_of(0)
    if getattr(leading_linear, "bias", None) is not None:
        raise ValueError("Unexpected bias tensors inside MOE modules.")

    # scaling factors
    def stack_scaling_factors(getter):
        first_factor = getter(linear_of(0))
        if first_factor is None:
            return None
        factors = [getter(linear_of(i)) for i in range(num_experts)]
        if first_factor.dtype == torch.float8_e4m3fn:
            # float8 factors are stacked through a uint8 view and reinterpreted
            # afterwards (presumably because stacking float8 directly is unsupported; TODO confirm).
            return torch.stack([factor.view(torch.uint8) for factor in factors]).view(
                torch.float8_e4m3fn
            )
        return torch.stack(factors)

    stacked_act_scale = stack_scaling_factors(get_activation_scaling_factor)
    # The moe plugin only supports a single activation scaling factor for all experts
    config.activation_scaling_factor = (
        None if stacked_act_scale is None else stacked_act_scale.max().unsqueeze(0)
    )

    config.weights_scaling_factor = stack_scaling_factors(get_weight_scaling_factor)
    config.weights_scaling_factor_2 = stack_scaling_factors(get_weight_scaling_factor_2)
    config.prequant_scaling_factor = stack_scaling_factors(get_prequant_scaling_factor)

    config.awq_block_size = get_weight_block_size(linear_of(0))
    config.quantization = get_quantization_format(experts)
    return config
def get_expert_linear_names(module: nn.Module) -> list[str]:
    """Get the list of linear names for the experts.

    Each known block name is matched case-insensitively as a substring of the
    module's class name, so wrappers such as ``QuantQwen3MoeSparseMoeBlock``
    resolve like ``Qwen3MoeSparseMoeBlock``. Families are tried in order.
    Unrecognized modules default to ``["w1", "w2", "w3"]``.
    """
    class_name = type(module).__name__.lower()

    # Built per call so every caller receives a fresh list.
    family_table = (
        (
            (
                "Qwen2MoeSparseMoeBlock",
                "Qwen3MoeSparseMoeBlock",
                "Qwen3NextSparseMoeBlock",
                "Qwen3_5MoeSparseMoeBlock",
                "Qwen3VLMoeTextSparseMoeBlock",
                "DeepseekMoE",
            ),
            ["gate_proj", "down_proj", "up_proj"],
        ),
        (("MixtralMoeSparseMoeBlock",), ["linear_fc1", "linear_fc2"]),
        (("DBRXMoeSparseMoeBlock",), ["w1_linear", "w2_linear", "v1_linear"]),
        # GPT-OSS MoE modules use gate_up_proj and down_proj
        (("GptOssMoE",), ["gate_up_proj", "down_proj"]),
    )
    for block_names, linear_names in family_table:
        if any(block_name.lower() in class_name for block_name in block_names):
            return linear_names

    # assuming w1, w2, w3 by default
    return ["w1", "w2", "w3"]
def set_expert_quantizer_amax(
modules: nn.Module | list[nn.Module],
quantizer_attrs: str | list[str] | None = None,
fallback_value: float = 0.5,
device: torch.device | None = None,
) -> list[nn.Module]:
"""Set amax values for expert quantizers using smart fallback logic.
Uses smart fallback logic: