-
Notifications
You must be signed in to change notification settings - Fork 512
Expand file tree
/
Copy pathparam_mapping.py
More file actions
2438 lines (2143 loc) · 114 KB
/
param_mapping.py
File metadata and controls
2438 lines (2143 loc) · 114 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parameter mappings and transformation hooks for checkpoint conversion.
This module defines the necessary components to convert model checkpoints between
MaxText and Hugging Face formats for various architectures (e.g., Gemma, Qwen).
It provides two key types of mappings for each model:
1. **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map a MaxText
parameter key to its corresponding Hugging Face parameter(s). These mappings are
generated by functions like `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`.
**Key: MaxText parameters, with following forms:**
- `atomic_mt_key`: A single string representing one MaxText parameter.
- `composite_mt_key`: A tuple of strings representing multiple MaxText parameters. (e.g., GPT-OSS)
**Value: corresponding Hugging Face parameters, with following forms:**
- `unscanned`: A single string.
- `scanned`: A list of strings, to be stacked along the layer axis.
- `unscanned with expert stacking`: A list of strings, to be stacked along the expert axis.
- `scanned with expert stacking`: A nested list of strings, to be stacked along both layer and expert axes.
Note: Expert stacking only applies a subset of MoE models (e.g., Qwen MoE, DeepSeek, Mixtral),
but not others (e.g., GPT-OSS).
2. **Hook Functions (`HOOK_FNS`)**: Dictionaries that map a MaxText parameter
name to a specific transformation function (a "hook"). These hooks handle
the actual value conversion, which can include operations like reshaping,
transposing, scaling, or padding tensors to match the target format's
requirements. These are generated by functions like
`GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN`.
The main conversion script uses these mappings to systematically transform each
parameter from the source checkpoint and build the target checkpoint.
"""
import warnings
import numpy as np
import jax
import jax.numpy as jnp
def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Generates a parameter mapping from MaxText to Hugging Face for Gemma3.
This function creates a dictionary that maps the parameter names from a
MaxText Gemma3 checkpoint to their corresponding names in the Hugging Face
`Gemma3ForCausalLM` format. It handles both the text and vision components
of the model.
Args:
config (dict): The Hugging Face model configuration dictionary, which must
contain 'text_config' and 'vision_config' sub-dictionaries.
scan_layers (bool, optional): If True, generates mappings for scanned
layers, where multiple layers are stacked into a single tensor. If False,
generates mappings for individual, unscanned layers. Defaults to False.
Returns:
dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). Values
are either a single Hugging Face parameter name (unscanned form) or a list of
Hugging Face parameter names (scanned form) for stacked text layers.
"""
tcfg = config["text_config"]
vcfg = config["vision_config"]
Ndec = tcfg["num_hidden_layers"]
Nvis = vcfg["num_hidden_layers"]
# pylint: disable=line-too-long
mapping = {
# Embedding & final norm
"params-token_embedder-embedding": "model.language_model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.language_model.norm.weight",
# Vision embed & pos
"params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel": "model.vision_tower.vision_model.embeddings.patch_embedding.weight",
"params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-bias": "model.vision_tower.vision_model.embeddings.patch_embedding.bias",
"params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding": "model.vision_tower.vision_model.embeddings.position_embedding.weight",
"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-scale": "model.vision_tower.vision_model.post_layernorm.weight",
"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-bias": "model.vision_tower.vision_model.post_layernorm.bias",
# Multi-modal projector
"params-vision_encoder-VisionEmbedder_0-mm_input_projection-w": "model.multi_modal_projector.mm_input_projection_weight",
"params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale": "model.multi_modal_projector.mm_soft_emb_norm.weight",
}
vision_params = [
("LayerNorm_0-scale", "layer_norm1.weight"),
("LayerNorm_0-bias", "layer_norm1.bias"),
("LayerNorm_1-scale", "layer_norm2.weight"),
("LayerNorm_1-bias", "layer_norm2.bias"),
("MultiHeadDotProductAttention_0-query-kernel", "self_attn.q_proj.weight"),
("MultiHeadDotProductAttention_0-query-bias", "self_attn.q_proj.bias"),
("MultiHeadDotProductAttention_0-key-kernel", "self_attn.k_proj.weight"),
("MultiHeadDotProductAttention_0-key-bias", "self_attn.k_proj.bias"),
("MultiHeadDotProductAttention_0-value-kernel", "self_attn.v_proj.weight"),
("MultiHeadDotProductAttention_0-value-bias", "self_attn.v_proj.bias"),
("MultiHeadDotProductAttention_0-out-kernel", "self_attn.out_proj.weight"),
("MultiHeadDotProductAttention_0-out-bias", "self_attn.out_proj.bias"),
("MlpBlockViT_0-Dense_0-kernel", "mlp.fc1.weight"),
("MlpBlockViT_0-Dense_0-bias", "mlp.fc1.bias"),
("MlpBlockViT_0-Dense_1-kernel", "mlp.fc2.weight"),
("MlpBlockViT_0-Dense_1-bias", "mlp.fc2.bias"),
]
# Vision layers mapping
for i in range(Nvis):
for mx, hf in vision_params:
key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
# Text decoder mapping
text_params = [
("pre_self_attention_norm-scale", "input_layernorm.weight"),
("post_self_attention_norm-scale", "post_attention_layernorm.weight"),
("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
("self_attention-query-kernel", "self_attn.q_proj.weight"),
("self_attention-key-kernel", "self_attn.k_proj.weight"),
("self_attention-value-kernel", "self_attn.v_proj.weight"),
("self_attention-out-kernel", "self_attn.o_proj.weight"),
("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
("mlp-wi_1-kernel", "mlp.up_proj.weight"),
("mlp-wo-kernel", "mlp.down_proj.weight"),
]
if scan_layers:
# Gemma3 repeats a 6-layer attention pattern (5 local + 1 global),
# scanned as layers_0..layers_5 with leftovers in layers_remainder.
attention_pattern_length = 6
num_remaining = Ndec % attention_pattern_length
num_scanned = Ndec - num_remaining
# Main scanned blocks: params-decoder-layers-layers_{block_idx}-{param}
for block_idx in range(attention_pattern_length):
hf_indices = list(range(block_idx, num_scanned, attention_pattern_length))
for mx, hf in text_params:
key = f"params-decoder-layers-layers_{block_idx}-{mx}"
mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in hf_indices]
# Remainder layers (unscanned): params-decoder-layers_remainder-layers_{rem_idx}-{param}
if num_remaining > 0:
for rem_idx in range(num_remaining):
hf_layer_idx = num_scanned + rem_idx
for mx, hf in text_params:
key = f"params-decoder-layers_remainder-layers_{rem_idx}-{mx}"
mapping[key] = f"model.language_model.layers.{hf_layer_idx}.{hf}"
else:
for i in range(Ndec):
for mx, hf in text_params:
key = f"params-decoder-layers_{i}-{mx}"
mapping[key] = f"model.language_model.layers.{i}.{hf}"
return mapping
def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""Hook functions for Gemma3 parameter conversion.
This function provides a dictionary of transformation functions (hooks) for
converting Gemma3 model parameters between MaxText and Hugging Face formats.
It handles embedding padding/scaling, RMSNorm scaling, kernel reshaping, and
vision-specific tensor manipulations.
Args:
config (dict): The Hugging Face model configuration dictionary.
scan_layers (bool, optional): Whether the model uses scanned layers.
Defaults to False.
saving_to_hf (bool, optional): The direction of conversion. True for
MaxText to Hugging Face, False for the reverse. Defaults to False.
Returns:
dict: A dictionary mapping MaxText parameter names to their corresponding
transformation functions.
"""
hooks = {}
# ---- Embedding pad & scale ----
def pad_and_scale_embedding(input_tensor, target_shape):
source_vocab_size, _ = input_tensor.shape
target_vocab_size, target_hidden_size = target_shape
# MaxText embedding = original_embedding * sqrt(hidden_size)
# HF embedding = original_embedding (HF model forward pass applies scaling)
# Note: config["hidden_size"] is the HF hidden size from the HF config object
normalizer = np.dtype("bfloat16").type(config["text_config"]["hidden_size"] ** 0.5)
# Apply scaling first
if saving_to_hf: # MaxText to HF
scaled_tensor = (input_tensor / normalizer).astype(input_tensor.dtype)
else: # HF to MaxText
scaled_tensor = (input_tensor * normalizer).astype(input_tensor.dtype)
# Handle padding/truncation
if source_vocab_size > target_vocab_size:
warnings.warn(
f"source vocab={source_vocab_size} > target vocab={target_vocab_size}, truncate output layer for MaxText."
)
output_tensor = scaled_tensor[:target_vocab_size, :]
elif source_vocab_size < target_vocab_size:
warnings.warn(f"source vocab={source_vocab_size} < target vocab={target_vocab_size}, pad output layer for MaxText.")
padding_shape = (target_vocab_size - source_vocab_size, target_hidden_size)
# Use jnp.zeros for JAX arrays, np.zeros for numpy arrays
padding = (
jnp.zeros(padding_shape, dtype=scaled_tensor.dtype)
if isinstance(scaled_tensor, jax.Array)
else np.zeros(padding_shape, dtype=scaled_tensor.dtype)
)
output_tensor = (
jnp.concatenate([scaled_tensor, padding], axis=0)
if isinstance(scaled_tensor, jax.Array)
else np.concatenate([scaled_tensor, padding], axis=0)
)
else: # Vocab sizes match
output_tensor = scaled_tensor
return output_tensor
# ---- RMSNorm scale ----
def scale_rmsnorm(x, target_shape):
# MaxText norm = HF norm +1; HF norm = MaxText norm -1
if saving_to_hf:
return (x - 1.0).reshape(target_shape)
return (x + 1.0).reshape(target_shape)
# ---- Generic reshape ----
def reshape_kernel(x, target_shape):
if saving_to_hf:
flipped = np.flip(np.array(target_shape))
return x.reshape(flipped).T
else:
return x.T.reshape(target_shape)
# ---- Vision reshape ----
def vis_bias(x, target_shape):
if saving_to_hf:
return x.flatten()
else:
return x.reshape(target_shape)
def vision_patch(x, target_shape):
if saving_to_hf:
return x.transpose(3, 2, 0, 1)
else:
return x.transpose(2, 3, 1, 0)
def pos_embed(x, target_shape):
if saving_to_hf:
return x.squeeze(0)
return x[None, :, :]
# ---Embedding & final norm---
hooks["params-token_embedder-embedding"] = pad_and_scale_embedding
hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm
# [1, 4096, 1152]
hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel"] = vision_patch
hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding"] = pos_embed
hooks["params-vision_encoder-VisionEmbedder_0-mm_input_projection-w"] = lambda x, _: x
hooks["params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale"] = scale_rmsnorm
# Text layers
tc = config.get("text_config", {})
nlayers = tc.get("num_hidden_layers", 0)
if scan_layers:
attention_pattern_length = 6
num_remaining = nlayers % attention_pattern_length
# Scanned sub-layer prefixes
prefixes = [f"params-decoder-layers-layers_{block_idx}-" for block_idx in range(attention_pattern_length)]
# Remainder sub-layer prefixes
if num_remaining > 0:
prefixes += [f"params-decoder-layers_remainder-layers_{rem_idx}-" for rem_idx in range(num_remaining)]
else:
prefixes = [f"params-decoder-layers_{i}-" for i in range(nlayers)]
for pref in prefixes:
# Attention Q/K/V/O
hooks[pref + "self_attention-query-kernel"] = reshape_kernel
hooks[pref + "self_attention-key-kernel"] = reshape_kernel
hooks[pref + "self_attention-value-kernel"] = reshape_kernel
hooks[pref + "self_attention-out-kernel"] = reshape_kernel
# Norm scales
for nm in [
"pre_self_attention_norm-scale",
"post_self_attention_norm-scale",
"self_attention-query_norm-scale",
"self_attention-key_norm-scale",
"pre_ffw_norm-scale",
"post_ffw_norm-scale",
]:
hooks[pref + nm] = scale_rmsnorm
# MLP
hooks[pref + "mlp-wi_0-kernel"] = reshape_kernel
hooks[pref + "mlp-wi_1-kernel"] = reshape_kernel
hooks[pref + "mlp-wo-kernel"] = reshape_kernel
# Vision layers
vc = config.get("vision_config", {})
nvis = vc.get("num_hidden_layers", 0)
for i in range(nvis):
base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
# Attention kernels & biases
for qkv in ["query", "key", "value"]:
hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-kernel"] = reshape_kernel
hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-bias"] = vis_bias
# [1152, 1152] -> [16, 72, 1152]
hooks[base + "MultiHeadDotProductAttention_0-out-kernel"] = reshape_kernel
hooks[base + "MultiHeadDotProductAttention_0-out-bias"] = vis_bias
# MLP ViT kernels & biases
for dense in ["Dense_0", "Dense_1"]:
hooks[base + f"MlpBlockViT_0-{dense}-kernel"] = reshape_kernel
return hooks
def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping between MaxText and HuggingFace Gemma2 weight paths.
Args:
config (dict): Model configuration dictionary containing at least
'num_hidden_layers'.
scan_layers (bool, optional): Whether the MaxText model uses layer
scanning optimization. When True, decoder layers are stacked into a
single tensor. Defaults to False.
Returns:
dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter name).
Values are either a single string (unscanned form) or a list of strings
(scanned form) for stacked layers when `scan_layers=True`.
Notes:
- MaxText uses a paired layer approach where two HF decoder layers are
treated as one MaxText decoder layer.
- MaxText layer `i` corresponds to HF layers `2i` and `2i+1`.
- Local components map to even-numbered HF decoder layers (0, 2, 4...).
- Global components map to odd-numbered HF decoder layers (1, 3, 5...).
"""
nlayers = config["num_hidden_layers"]
mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.norm.weight",
}
if scan_layers:
mapping = {
**mapping,
"params-decoder-layers-pre_self_attention_norm_global-scale": [
f"model.layers.{i}.input_layernorm.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-mlp_global-wo-kernel": [
f"model.layers.{i}.mlp.down_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-mlp_global-wi_1-kernel": [
f"model.layers.{i}.mlp.up_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-mlp_global-wi_0-kernel": [
f"model.layers.{i}.mlp.gate_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-post_self_attention_norm_global-scale": [
f"model.layers.{i}.post_attention_layernorm.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-post_ffw_norm_global-scale": [
f"model.layers.{i}.post_feedforward_layernorm.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-pre_ffw_norm_global-scale": [
f"model.layers.{i}.pre_feedforward_layernorm.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-self_attention_global-key-kernel": [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-self_attention_global-out-kernel": [
f"model.layers.{i}.self_attn.o_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-self_attention_global-query-kernel": [
f"model.layers.{i}.self_attn.q_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-self_attention_global-value-kernel": [
f"model.layers.{i}.self_attn.v_proj.weight" for i in range(1, nlayers, 2)
],
"params-decoder-layers-pre_self_attention_norm_local-scale": [
f"model.layers.{i}.input_layernorm.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-mlp_local-wo-kernel": [
f"model.layers.{i}.mlp.down_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-mlp_local-wi_1-kernel": [
f"model.layers.{i}.mlp.up_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-mlp_local-wi_0-kernel": [
f"model.layers.{i}.mlp.gate_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-post_self_attention_norm_local-scale": [
f"model.layers.{i}.post_attention_layernorm.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-post_ffw_norm_local-scale": [
f"model.layers.{i}.post_feedforward_layernorm.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-pre_ffw_norm_local-scale": [
f"model.layers.{i}.pre_feedforward_layernorm.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-self_attention_local-key-kernel": [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-self_attention_local-out-kernel": [
f"model.layers.{i}.self_attn.o_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-self_attention_local-query-kernel": [
f"model.layers.{i}.self_attn.q_proj.weight" for i in range(0, nlayers, 2)
],
"params-decoder-layers-self_attention_local-value-kernel": [
f"model.layers.{i}.self_attn.v_proj.weight" for i in range(0, nlayers, 2)
],
}
# Case 2: scan_layer=False
else:
for maxtext_layer_idx in range(0, nlayers // 2):
local_layer_idx = maxtext_layer_idx * 2
global_layer_idx = maxtext_layer_idx * 2 + 1
# pylint: disable=line-too-long
layer_mapping = {
f"params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": f"model.layers.{global_layer_idx}.input_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wo-kernel": f"model.layers.{global_layer_idx}.mlp.down_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wi_1-kernel": f"model.layers.{global_layer_idx}.mlp.up_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wi_0-kernel": f"model.layers.{global_layer_idx}.mlp.gate_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-post_self_attention_norm_global-scale": f"model.layers.{global_layer_idx}.post_attention_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-post_ffw_norm_global-scale": f"model.layers.{global_layer_idx}.post_feedforward_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-pre_ffw_norm_global-scale": f"model.layers.{global_layer_idx}.pre_feedforward_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": f"model.layers.{global_layer_idx}.self_attn.k_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-out-kernel": f"model.layers.{global_layer_idx}.self_attn.o_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-query-kernel": f"model.layers.{global_layer_idx}.self_attn.q_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-value-kernel": f"model.layers.{global_layer_idx}.self_attn.v_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_local-scale": f"model.layers.{local_layer_idx}.input_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wo-kernel": f"model.layers.{local_layer_idx}.mlp.down_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wi_1-kernel": f"model.layers.{local_layer_idx}.mlp.up_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wi_0-kernel": f"model.layers.{local_layer_idx}.mlp.gate_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-post_self_attention_norm_local-scale": f"model.layers.{local_layer_idx}.post_attention_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-post_ffw_norm_local-scale": f"model.layers.{local_layer_idx}.post_feedforward_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-pre_ffw_norm_local-scale": f"model.layers.{local_layer_idx}.pre_feedforward_layernorm.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-key-kernel": f"model.layers.{local_layer_idx}.self_attn.k_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-out-kernel": f"model.layers.{local_layer_idx}.self_attn.o_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-query-kernel": f"model.layers.{local_layer_idx}.self_attn.q_proj.weight",
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-value-kernel": f"model.layers.{local_layer_idx}.self_attn.v_proj.weight",
}
mapping = {**mapping, **layer_mapping}
return mapping
def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""Creates parameter transformation functions for Gemma2 conversion.
This function generates a mapping of transformation functions that handle the
necessary conversions between MaxText and HuggingFace parameter formats for
Gemma2, including operations like padding, reshaping, and scaling.
Args:
config (dict): Model configuration dictionary that must contain:
- num_hidden_layers (int): Number of layers in the model.
- head_dim (int): Dimension of attention heads.
- hidden_size (int): Model's hidden dimension size.
scan_layers (bool, optional): Controls the output format for layer
parameters. True for batched, False for individual. Defaults to False.
saving_to_hf (bool, optional): Determines the direction of transformation.
True for MaxText to HuggingFace, False for the reverse. Defaults to
False.
Returns:
dict: A mapping from MaxText parameter names to transformation functions.
The value can be a single function or a list of functions to be
applied sequentially.
"""
nlayers = config["num_hidden_layers"]
def pad_hf_embedding_layer(input_tensor, target_shape):
"""Pads/unpads and scales the embedding layer.
Note:
HF embedding weights shape = [256000, d_model]
MaxText embedding weights shape = [256128, d_model]
MaxText pads Gemma2 embedding to 256128 for better performance.
"""
# TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype
normalizer = np.dtype("float32").type(config["hidden_size"] ** 0.5)
if saving_to_hf:
target_tensor = input_tensor[: target_shape[0], : target_shape[1]]
target_tensor = target_tensor / normalizer
target_tensor = target_tensor.astype(input_tensor.dtype)
return target_tensor
else:
target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)
target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
target_tensor = target_tensor * normalizer
target_tensor = target_tensor.astype(input_tensor.dtype)
return target_tensor
def reshape_kernel(input_tensor, target_shape):
if saving_to_hf:
flipped_target_shape = np.flip(np.array(target_shape))
return input_tensor.reshape(flipped_target_shape).T
else:
return input_tensor.T.reshape(target_shape)
def scale_rmsnorm_layer(input_tensor, target_shape):
if saving_to_hf:
return (input_tensor - 1.0).reshape(target_shape)
else:
return (input_tensor + 1.0).reshape(target_shape)
def scale_query_layer(input_tensor, target_shape):
if saving_to_hf:
depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"]))
return (input_tensor * depth_scale).astype(input_tensor.dtype)
else:
depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"]))
return (input_tensor * depth_scale).astype(input_tensor.dtype)
# hook order does not affect result
query_hook_chain = [reshape_kernel, scale_query_layer]
mapping = {
"params-token_embedder-embedding": pad_hf_embedding_layer,
"params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
}
if scan_layers:
mapping = {
**mapping,
"params-decoder-layers-self_attention_global-query-kernel": query_hook_chain,
"params-decoder-layers-self_attention_local-query-kernel": query_hook_chain,
"params-decoder-layers-self_attention_global-key-kernel": reshape_kernel,
"params-decoder-layers-self_attention_local-key-kernel": reshape_kernel,
"params-decoder-layers-self_attention_global-value-kernel": reshape_kernel,
"params-decoder-layers-self_attention_local-value-kernel": reshape_kernel,
"params-decoder-layers-mlp_global-wo-kernel": reshape_kernel,
"params-decoder-layers-mlp_global-wi_1-kernel": reshape_kernel,
"params-decoder-layers-mlp_global-wi_0-kernel": reshape_kernel,
"params-decoder-layers-self_attention_global-out-kernel": reshape_kernel,
"params-decoder-layers-mlp_local-wo-kernel": reshape_kernel,
"params-decoder-layers-mlp_local-wi_1-kernel": reshape_kernel,
"params-decoder-layers-mlp_local-wi_0-kernel": reshape_kernel,
"params-decoder-layers-self_attention_local-out-kernel": reshape_kernel,
"params-decoder-layers-pre_self_attention_norm_global-scale": scale_rmsnorm_layer,
"params-decoder-layers-post_self_attention_norm_global-scale": scale_rmsnorm_layer,
"params-decoder-layers-post_ffw_norm_global-scale": scale_rmsnorm_layer,
"params-decoder-layers-pre_ffw_norm_global-scale": scale_rmsnorm_layer,
"params-decoder-layers-pre_self_attention_norm_local-scale": scale_rmsnorm_layer,
"params-decoder-layers-post_self_attention_norm_local-scale": scale_rmsnorm_layer,
"params-decoder-layers-post_ffw_norm_local-scale": scale_rmsnorm_layer,
"params-decoder-layers-pre_ffw_norm_local-scale": scale_rmsnorm_layer,
}
else:
for maxtext_layer_idx in range(nlayers // 2):
mapping = {
**mapping,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-query-kernel": query_hook_chain,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-query-kernel": query_hook_chain,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-key-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-value-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-value-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wo-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wi_1-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_global-wi_0-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-out-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wo-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wi_1-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-mlp_local-wi_0-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-out-kernel": reshape_kernel,
f"params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-post_self_attention_norm_global-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-post_ffw_norm_global-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-pre_ffw_norm_global-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_local-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-post_self_attention_norm_local-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-post_ffw_norm_local-scale": scale_rmsnorm_layer,
f"params-decoder-layers_{maxtext_layer_idx}-pre_ffw_norm_local-scale": scale_rmsnorm_layer,
}
return mapping
def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping from MaxText to HuggingFace Qwen weight paths.
This function generates a dictionary that maps parameter names from a MaxText
Qwen checkpoint to their corresponding names in the Hugging Face format.
It handles both dense and Mixture-of-Experts (MoE) model variants.
Args:
config (dict): Model configuration dictionary, including
'num_hidden_layers' and optionally 'num_experts'.
scan_layers (bool, optional): Whether the MaxText model uses scanned
layers. Defaults to False.
Returns:
dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names).
Values are Hugging Face parameter names in one of four forms: unscanned (string),
scanned (list of strings), unscanned with expert stacking (list of strings),
or scanned with expert stacking (nested list of strings).
"""
n_layers = config["num_hidden_layers"]
num_experts = config.get("num_experts", 0)
mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.norm.weight",
"params-decoder-logits_dense-kernel": "lm_head.weight",
}
if scan_layers:
# This block handles scanned layers for both dense and MoE models.
mapping.update(
{
"params-decoder-layers-pre_self_attention_layer_norm-scale": [
f"model.layers.{i}.input_layernorm.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-query-kernel": [
f"model.layers.{i}.self_attn.q_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-key-kernel": [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-value-kernel": [
f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-query-bias": [
f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-key-bias": [
f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-value-bias": [
f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-out-kernel": [
f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-query_norm-scale": [
f"model.layers.{i}.self_attn.q_norm.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-key_norm-scale": [
f"model.layers.{i}.self_attn.k_norm.weight" for i in range(n_layers)
],
"params-decoder-layers-post_self_attention_layer_norm-scale": [
f"model.layers.{i}.post_attention_layernorm.weight" for i in range(n_layers)
],
}
)
if num_experts > 1:
# For scanned MoE, we create a nested list: [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..]
# This follows the (experts, layers, ...) tensor layout.
mapping.update(
{
"params-decoder-layers-moe_block-gate-kernel": [
f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)
],
"params-decoder-layers-moe_block-wi_0": [
[f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight" for l in range(n_layers)]
for e in range(num_experts)
],
"params-decoder-layers-moe_block-wi_1": [
[f"model.layers.{l}.mlp.experts.{e}.up_proj.weight" for l in range(n_layers)]
for e in range(num_experts)
],
"params-decoder-layers-moe_block-wo": [
[f"model.layers.{l}.mlp.experts.{e}.down_proj.weight" for l in range(n_layers)]
for e in range(num_experts)
],
}
)
else: # Dense MLP
mapping.update(
{
"params-decoder-layers-mlp-wi_0-kernel": [
f"model.layers.{i}.mlp.gate_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-mlp-wi_1-kernel": [f"model.layers.{i}.mlp.up_proj.weight" for i in range(n_layers)],
"params-decoder-layers-mlp-wo-kernel": [f"model.layers.{i}.mlp.down_proj.weight" for i in range(n_layers)],
}
)
else: # unscanned layers
for i in range(n_layers):
# Common Attention and Norms
# pylint: disable=line-too-long
mapping.update(
{
f"params-decoder-layers_{i}-pre_self_attention_layer_norm-scale": f"model.layers.{i}.input_layernorm.weight",
f"params-decoder-layers_{i}-self_attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
}
)
if num_experts > 1:
# For each unscanned MoE layer, map the MaxText parameter to a 1D list of all expert weights for that layer.
mapping.update(
{
f"params-decoder-layers_{i}-moe_block-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
f"params-decoder-layers_{i}-moe_block-wi_0": [
f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)
],
f"params-decoder-layers_{i}-moe_block-wi_1": [
f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)
],
f"params-decoder-layers_{i}-moe_block-wo": [
f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)
],
}
)
else: # Dense MLP
mapping.update(
{
f"params-decoder-layers_{i}-mlp-wi_0-kernel": f"model.layers.{i}.mlp.gate_proj.weight",
f"params-decoder-layers_{i}-mlp-wi_1-kernel": f"model.layers.{i}.mlp.up_proj.weight",
f"params-decoder-layers_{i}-mlp-wo-kernel": f"model.layers.{i}.mlp.down_proj.weight",
}
)
return mapping
def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""Creates parameter transformation functions for Qwen.
This function provides a dictionary of transformation functions (hooks) for
converting Qwen model parameters between MaxText and Hugging Face formats.
It handles embedding padding and kernel reshaping.
Args:
config (dict): Model configuration dictionary, including
'num_hidden_layers' and optionally 'num_experts'.
scan_layers (bool, optional): Whether the model uses scanned layers.
Defaults to False.
saving_to_hf (bool, optional): The direction of conversion. True for
MaxText to Hugging Face, False for the reverse. Defaults to False.
Returns:
dict: A dictionary mapping MaxText parameter names to their corresponding
transformation functions.
"""
n_layers = config["num_hidden_layers"]
num_experts = config.get("num_experts", 0)
def pad_embedding_layer(input_tensor, target_shape):
"""Pads or truncates embedding layer to match target vocab size."""
source_vocab_size = input_tensor.shape[0]
target_vocab_size = target_shape[0]
if source_vocab_size == target_vocab_size:
return input_tensor
if saving_to_hf: # MaxText to HF, truncate
return input_tensor[:target_vocab_size, :]
else: # HF to MaxText, pad
padded_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)
padded_tensor[:source_vocab_size, :] = input_tensor
return padded_tensor
def reshape_kernel(input_tensor, target_shape):
"""Reshapes and transposes kernel weights between MaxText and HF."""
if saving_to_hf:
flipped_target_shape = np.flip(np.array(target_shape))
return input_tensor.reshape(flipped_target_shape).T
else:
return input_tensor.T.reshape(target_shape)
def reshape_bias(input_tensor, target_shape=None):
"""Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
# saving_to_hf: MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
# loading_to_maxtext: HF [hidden_dim] -> MaxText [heads, head_dim]
return input_tensor.reshape(target_shape)
mapping = {
"params-token_embedder-embedding": pad_embedding_layer,
"params-decoder-logits_dense-kernel": reshape_kernel,
}
kernel_hooks = [
"self_attention-query-kernel",
"self_attention-key-kernel",
"self_attention-value-kernel",
"self_attention-out-kernel",
"mlp-wi_0-kernel",
"mlp-wi_1-kernel",
"mlp-wo-kernel",
]
bias_hooks = [
"self_attention-query-bias",
"self_attention-key-bias",
"self_attention-value-bias",
]
moe_kernel_hooks = [
"moe_block-gate-kernel",
"moe_block-wi_0-kernel",
"moe_block-wi_1-kernel",
"moe_block-wo-kernel",
"moe_block-wi_0",
"moe_block-wi_1",
"moe_block-wo",
]
if scan_layers:
for key in kernel_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
for key in bias_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_bias
if num_experts > 1:
for key in moe_kernel_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
else:
for i in range(n_layers):
for key in kernel_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
for key in bias_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
if num_experts > 1:
for key in moe_kernel_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
return mapping
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""
Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
All MaxText keys start with 'params-' and use '-' separators for scanned layers.
"""
num_main_layers = config["num_hidden_layers"]
num_experts = config["num_experts"]
layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
# 1. Non-layer specific weight mappings
mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.norm.weight",
"params-decoder-logits_dense-kernel": "lm_head.weight",
}
if scan_layers:
# 2. Scan over block cycles
for block_idx in range(layer_cycle_interval):
hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
prefix = f"params-decoder-layers-layer_{block_idx}"
# Layer norms
mapping[f"{prefix}-input_layernorm-scale"] = [f"model.layers.{i}.input_layernorm.weight" for i in hf_indices]
mapping[f"{prefix}-post_attention_layernorm-scale"] = [
f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
]
# Handle Interleaved Attention (Linear vs Full)
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
if is_full_attention_layer:
mapping.update(
{
f"{prefix}-attention-attention-query-kernel": [
f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices
],
f"{prefix}-attention-attention-key-kernel": [
f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices
],
f"{prefix}-attention-attention-value-kernel": [
f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices
],
f"{prefix}-attention-attention-out-kernel": [
f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices
],
f"{prefix}-attention-attention-query_norm-scale": [
f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices
],
f"{prefix}-attention-attention-key_norm-scale": [
f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices
],
}
)
else:
# Linear/Hybrid Attention Block
mapping.update(
{
f"{prefix}-attention-in_proj_qkvz-kernel": [
f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices
],
f"{prefix}-attention-in_proj_ba-kernel": [
f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices
],
f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
f"{prefix}-attention-norm-rms_norm-scale": [
f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices
],
f"{prefix}-attention-out_proj-kernel": [
f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices
],
}
)
# 3. Handle MLP: Gates and Shared Experts
mapping.update(
{
f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wi_0-kernel": [
f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices
],
f"{prefix}-mlp-shared_expert-wi_1-kernel": [
f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices
],
f"{prefix}-mlp-shared_expert-wo-kernel": [
f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices
],
f"{prefix}-mlp-shared_expert_gate-kernel": [
f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices
],
}
)
# 4. Handle MoE Routed Experts
mapping.update(
{
f"{prefix}-mlp-routed_experts-wi_0": [
[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)
],
f"{prefix}-mlp-routed_experts-wi_1": [
[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)
],
f"{prefix}-mlp-routed_experts-wo": [
[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)
],
}
)
else:
# Unscanned layer mapping
for i in range(num_main_layers):
prefix = f"params-decoder-layers_{i}"
# Layer Norms
mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"
# Determine layer type based on cycle interval
# Assuming block logic: layer i corresponds to block_idx = i % interval
block_idx = i % layer_cycle_interval
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
if is_full_attention_layer:
mapping.update(
{
f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
}
)
else:
# Linear/Hybrid Attention Block
mapping.update(
{
f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
}
)
# MLP: Gates and Shared Experts
mapping.update(
{
f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
}
)
# MoE Routed Experts (List of expert weights for this specific layer)
mapping.update(
{
f"{prefix}-mlp-routed_experts-wi_0": [
f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)
],