@@ -29,15 +29,10 @@ def test_gemma3_mapping_unscanned(self):
         "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertEqual(
-        mapping["params-token_embedder-embedding"],
-        "model.language_model.embed_tokens.weight",
-    )
+    self.assertEqual(mapping["params-token_embedder-embedding"], "model.language_model.embed_tokens.weight")

     # Check text decoder layer 0
     self.assertIn("params-decoder-layers_0-pre_self_attention_norm-scale", mapping)
@@ -48,13 +43,10 @@ def test_gemma3_mapping_unscanned(self):

     # Check vision encoder layer 0
     self.assertIn(
-        "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale",
-        mapping,
+        "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale", mapping
     )
     self.assertEqual(
-        mapping[
-            "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale"
-        ],
+        mapping["params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale"],
         "model.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight",
     )

@@ -64,36 +56,23 @@ def test_gemma3_mapping_scanned(self):
         "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-token_embedder-embedding", mapping)

     # Check scanned block mapping
-    self.assertIn(
-        "params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]
-        ),
-        2,
-    )
+    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"], list)
+    # Gemma3 repeats a 6-layer pattern. 12 layers means 2 of each.
+    self.assertEqual(len(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]), 2)

   def test_gemma3_hooks(self):
     config = {
         "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
         "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
     }
     maxtext_config = mock.Mock()
-    hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-token_embedder-embedding", hooks)

@@ -112,35 +91,25 @@ def test_gemma2_mapping(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     # Gemma2 maps MaxText layer i to HF layers 2i and 2i+1
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping
-    )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping)
     self.assertEqual(
-        mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"],
-        "model.layers.0.input_layernorm.weight",
-    )
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping
+        mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"], "model.layers.0.input_layernorm.weight"
     )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping)
     self.assertEqual(
-        mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"],
-        "model.layers.1.input_layernorm.weight",
+        mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"], "model.layers.1.input_layernorm.weight"
     )

   def test_qwen_mapping_dense(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-mlp-wi_0-kernel", mapping)
@@ -151,14 +120,10 @@ def test_qwen_mapping_moe(self):
         "num_experts": 4,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)
-    self.assertIsInstance(
-        mapping["params-decoder-layers_0-moe_block-wi_0"], list
-    )
+    self.assertIsInstance(mapping["params-decoder-layers_0-moe_block-wi_0"], list)
     self.assertEqual(len(mapping["params-decoder-layers_0-moe_block-wi_0"]), 4)

   def test_qwen3_next_mapping(self):
@@ -168,9 +133,7 @@ def test_qwen3_next_mapping(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 2
-    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-input_layernorm-scale", mapping)
@@ -182,43 +145,32 @@ def test_deepseek_mapping(self):
         "n_routed_experts": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     # Layer 0 is dense
     self.assertIn("params-decoder-dense_layers_0-mlp-wi_0-kernel", mapping)
     # Layer 1 is MoE
-    self.assertIn(
-        "params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)

   def test_gpt_oss_mapping(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 1
-    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping
-    )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping)

   def test_mixtral_mapping(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
     maxtext_config.num_experts = 4
-    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-MoeBlock_0-gate-kernel", mapping)
@@ -231,9 +183,7 @@ def test_gemma4_mapping(self):
     maxtext_config.share_kv_projections = False
     maxtext_config.use_multimodal = False
     maxtext_config.v_norm_with_scale = False
-    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)

@@ -244,9 +194,7 @@ def test_gemma2_hooks(self):
         "head_dim": 64,
     }
     maxtext_config = mock.Mock()
-    hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-token_embedder-embedding", hooks)

@@ -266,36 +214,20 @@ def test_gemma2_mapping_scanned(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertIn(
-        "params-decoder-layers-pre_self_attention_norm_local-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-pre_self_attention_norm_local-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping[
-                "params-decoder-layers-pre_self_attention_norm_local-scale"
-            ]
-        ),
-        2,
-    )
+    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"], list)
+    self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"]), 2)

   def test_qwen_hooks(self):
     config = {
         "num_hidden_layers": 2,
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-token_embedder-embedding", hooks)

@@ -312,25 +244,11 @@ def test_qwen_mapping_scanned(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-layers-pre_self_attention_layer_norm-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping[
-                "params-decoder-layers-pre_self_attention_layer_norm-scale"
-            ]
-        ),
-        4,
-    )
+    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"], list)
+    self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"]), 4)

   def test_deepseek_hooks(self):
     config = {
@@ -352,17 +270,10 @@ def test_deepseek_mapping_scanned(self):
         "n_routed_experts": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-dense_layers-self_attention-query-kernel", mapping
-    )
-    self.assertIn(
-        "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)
+    self.assertIn("params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)

   def test_gpt_oss_hooks(self):
     config = {
@@ -371,9 +282,7 @@ def test_gpt_oss_hooks(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 1
-    hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-decoder-logits_dense-kernel", hooks)

@@ -383,14 +292,9 @@ def test_gpt_oss_mapping_scanned(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 2
-    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale",
-        mapping,
-    )
+    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

   def test_mixtral_hooks(self):
     config = {
@@ -414,9 +318,7 @@ class Config:
       num_experts = 4

     maxtext_config = Config()
-    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

@@ -428,14 +330,9 @@ def test_gemma4_mapping_scanned(self):
     maxtext_config.share_kv_projections = False
     maxtext_config.use_multimodal = False
     maxtext_config.v_norm_with_scale = False
-    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-scanned_blocks-layers_0-self_attention-query-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)


 if __name__ == "__main__":