@@ -29,15 +29,10 @@ def test_gemma3_mapping_unscanned(self):
2929 "vision_config" : {"num_hidden_layers" : 1 , "hidden_size" : 128 },
3030 }
3131 maxtext_config = mock .Mock ()
32- mapping = param_mapping .GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (
33- config , maxtext_config , scan_layers = False
34- )
32+ mapping = param_mapping .GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
3533
3634 self .assertIn ("params-token_embedder-embedding" , mapping )
37- self .assertEqual (
38- mapping ["params-token_embedder-embedding" ],
39- "model.language_model.embed_tokens.weight" ,
40- )
35+ self .assertEqual (mapping ["params-token_embedder-embedding" ], "model.language_model.embed_tokens.weight" )
4136
4237 # Check text decoder layer 0
4338 self .assertIn ("params-decoder-layers_0-pre_self_attention_norm-scale" , mapping )
@@ -48,13 +43,10 @@ def test_gemma3_mapping_unscanned(self):
4843
4944 # Check vision encoder layer 0
5045 self .assertIn (
51- "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale" ,
52- mapping ,
46+ "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale" , mapping
5347 )
5448 self .assertEqual (
55- mapping [
56- "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale"
57- ],
49+ mapping ["params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale" ],
5850 "model.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight" ,
5951 )
6052
@@ -64,26 +56,15 @@ def test_gemma3_mapping_scanned(self):
6456 "vision_config" : {"num_hidden_layers" : 1 , "hidden_size" : 128 },
6557 }
6658 maxtext_config = mock .Mock ()
67- mapping = param_mapping .GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (
68- config , maxtext_config , scan_layers = True
69- )
59+ mapping = param_mapping .GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
7060
7161 self .assertIn ("params-token_embedder-embedding" , mapping )
7262
7363 # Check scanned block mapping
74- self .assertIn (
75- "params-decoder-layers-layers_0-pre_self_attention_norm-scale" , mapping
76- )
77- self .assertIsInstance (
78- mapping ["params-decoder-layers-layers_0-pre_self_attention_norm-scale" ],
79- list ,
80- )
81- self .assertEqual (
82- len (
83- mapping ["params-decoder-layers-layers_0-pre_self_attention_norm-scale" ]
84- ),
85- 2 ,
86- )
64+ self .assertIn ("params-decoder-layers-layers_0-pre_self_attention_norm-scale" , mapping )
65+ self .assertIsInstance (mapping ["params-decoder-layers-layers_0-pre_self_attention_norm-scale" ], list )
66+ # Gemma3 repeats a 6-layer pattern. 12 layers means 2 of each.
67+ self .assertEqual (len (mapping ["params-decoder-layers-layers_0-pre_self_attention_norm-scale" ]), 2 )
8768
8869 def test_gemma3_hooks (self ):
8970 config = {
@@ -112,35 +93,25 @@ def test_gemma2_mapping(self):
11293 "hidden_size" : 256 ,
11394 }
11495 maxtext_config = mock .Mock ()
115- mapping = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (
116- config , maxtext_config , scan_layers = False
117- )
96+ mapping = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
11897
11998 self .assertIn ("params-token_embedder-embedding" , mapping )
12099 # Gemma2 maps MaxText layer i to HF layers 2i and 2i+1
121- self .assertIn (
122- "params-decoder-layers_0-pre_self_attention_norm_local-scale" , mapping
123- )
100+ self .assertIn ("params-decoder-layers_0-pre_self_attention_norm_local-scale" , mapping )
124101 self .assertEqual (
125- mapping ["params-decoder-layers_0-pre_self_attention_norm_local-scale" ],
126- "model.layers.0.input_layernorm.weight" ,
127- )
128- self .assertIn (
129- "params-decoder-layers_0-pre_self_attention_norm_global-scale" , mapping
102+ mapping ["params-decoder-layers_0-pre_self_attention_norm_local-scale" ], "model.layers.0.input_layernorm.weight"
130103 )
104+ self .assertIn ("params-decoder-layers_0-pre_self_attention_norm_global-scale" , mapping )
131105 self .assertEqual (
132- mapping ["params-decoder-layers_0-pre_self_attention_norm_global-scale" ],
133- "model.layers.1.input_layernorm.weight" ,
106+ mapping ["params-decoder-layers_0-pre_self_attention_norm_global-scale" ], "model.layers.1.input_layernorm.weight"
134107 )
135108
136109 def test_qwen_mapping_dense (self ):
137110 config = {
138111 "num_hidden_layers" : 2 ,
139112 }
140113 maxtext_config = mock .Mock ()
141- mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (
142- config , maxtext_config , scan_layers = False
143- )
114+ mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
144115
145116 self .assertIn ("params-token_embedder-embedding" , mapping )
146117 self .assertIn ("params-decoder-layers_0-mlp-wi_0-kernel" , mapping )
@@ -151,14 +122,10 @@ def test_qwen_mapping_moe(self):
151122 "num_experts" : 4 ,
152123 }
153124 maxtext_config = mock .Mock ()
154- mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (
155- config , maxtext_config , scan_layers = False
156- )
125+ mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
157126
158127 self .assertIn ("params-decoder-layers_0-moe_block-wi_0" , mapping )
159- self .assertIsInstance (
160- mapping ["params-decoder-layers_0-moe_block-wi_0" ], list
161- )
128+ self .assertIsInstance (mapping ["params-decoder-layers_0-moe_block-wi_0" ], list )
162129 self .assertEqual (len (mapping ["params-decoder-layers_0-moe_block-wi_0" ]), 4 )
163130
164131 def test_qwen3_next_mapping (self ):
@@ -168,9 +135,7 @@ def test_qwen3_next_mapping(self):
168135 }
169136 maxtext_config = mock .Mock ()
170137 maxtext_config .inhomogeneous_layer_cycle_interval = 2
171- mapping = param_mapping .QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING (
172- config , maxtext_config , scan_layers = False
173- )
138+ mapping = param_mapping .QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
174139
175140 self .assertIn ("params-token_embedder-embedding" , mapping )
176141 self .assertIn ("params-decoder-layers_0-input_layernorm-scale" , mapping )
@@ -182,43 +147,32 @@ def test_deepseek_mapping(self):
182147 "n_routed_experts" : 2 ,
183148 }
184149 maxtext_config = mock .Mock ()
185- mapping = param_mapping .DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (
186- config , maxtext_config , scan_layers = False
187- )
150+ mapping = param_mapping .DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
188151
189152 self .assertIn ("params-token_embedder-embedding" , mapping )
190153 # Layer 0 is dense
191154 self .assertIn ("params-decoder-dense_layers_0-mlp-wi_0-kernel" , mapping )
192155 # Layer 1 is MoE
193- self .assertIn (
194- "params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel" ,
195- mapping ,
196- )
156+ self .assertIn ("params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel" , mapping )
197157
198158 def test_gpt_oss_mapping (self ):
199159 config = {
200160 "num_hidden_layers" : 2 ,
201161 }
202162 maxtext_config = mock .Mock ()
203163 maxtext_config .inhomogeneous_layer_cycle_interval = 1
204- mapping = param_mapping .GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (
205- config , maxtext_config , scan_layers = False
206- )
164+ mapping = param_mapping .GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
207165
208166 self .assertIn ("params-token_embedder-embedding" , mapping )
209- self .assertIn (
210- "params-decoder-layers_0-pre_self_attention_layer_norm-scale" , mapping
211- )
167+ self .assertIn ("params-decoder-layers_0-pre_self_attention_layer_norm-scale" , mapping )
212168
213169 def test_mixtral_mapping (self ):
214170 config = {
215171 "num_hidden_layers" : 2 ,
216172 }
217173 maxtext_config = mock .Mock ()
218174 maxtext_config .num_experts = 4
219- mapping = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING (
220- config , maxtext_config , scan_layers = False
221- )
175+ mapping = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
222176
223177 self .assertIn ("params-token_embedder-embedding" , mapping )
224178 self .assertIn ("params-decoder-layers_0-MoeBlock_0-gate-kernel" , mapping )
@@ -231,9 +185,7 @@ def test_gemma4_mapping(self):
231185 maxtext_config .share_kv_projections = False
232186 maxtext_config .use_multimodal = False
233187 maxtext_config .v_norm_with_scale = False
234- mapping = param_mapping .GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING (
235- config , maxtext_config , scan_layers = False
236- )
188+ mapping = param_mapping .GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False )
237189
238190 self .assertIn ("params-token_embedder-embedding" , mapping )
239191
@@ -244,9 +196,7 @@ def test_gemma2_hooks(self):
244196 "head_dim" : 64 ,
245197 }
246198 maxtext_config = mock .Mock ()
247- hooks = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN (
248- config , maxtext_config , scan_layers = False , saving_to_hf = True
249- )
199+ hooks = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = True )
250200
251201 self .assertIn ("params-token_embedder-embedding" , hooks )
252202
@@ -266,36 +216,20 @@ def test_gemma2_mapping_scanned(self):
266216 "hidden_size" : 256 ,
267217 }
268218 maxtext_config = mock .Mock ()
269- mapping = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (
270- config , maxtext_config , scan_layers = True
271- )
219+ mapping = param_mapping .GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
272220
273221 self .assertIn ("params-token_embedder-embedding" , mapping )
274- self .assertIn (
275- "params-decoder-layers-pre_self_attention_norm_local-scale" , mapping
276- )
277- self .assertIsInstance (
278- mapping ["params-decoder-layers-pre_self_attention_norm_local-scale" ],
279- list ,
280- )
281- self .assertEqual (
282- len (
283- mapping [
284- "params-decoder-layers-pre_self_attention_norm_local-scale"
285- ]
286- ),
287- 2 ,
288- )
222+ self .assertIn ("params-decoder-layers-pre_self_attention_norm_local-scale" , mapping )
223+ self .assertIsInstance (mapping ["params-decoder-layers-pre_self_attention_norm_local-scale" ], list )
224+ self .assertEqual (len (mapping ["params-decoder-layers-pre_self_attention_norm_local-scale" ]), 2 )
289225
290226 def test_qwen_hooks (self ):
291227 config = {
292228 "num_hidden_layers" : 2 ,
293229 "hidden_size" : 256 ,
294230 }
295231 maxtext_config = mock .Mock ()
296- hooks = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN (
297- config , maxtext_config , scan_layers = False , saving_to_hf = True
298- )
232+ hooks = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = True )
299233
300234 self .assertIn ("params-token_embedder-embedding" , hooks )
301235
@@ -312,25 +246,11 @@ def test_qwen_mapping_scanned(self):
312246 "hidden_size" : 256 ,
313247 }
314248 maxtext_config = mock .Mock ()
315- mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (
316- config , maxtext_config , scan_layers = True
317- )
249+ mapping = param_mapping .QWEN_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
318250
319- self .assertIn (
320- "params-decoder-layers-pre_self_attention_layer_norm-scale" , mapping
321- )
322- self .assertIsInstance (
323- mapping ["params-decoder-layers-pre_self_attention_layer_norm-scale" ],
324- list ,
325- )
326- self .assertEqual (
327- len (
328- mapping [
329- "params-decoder-layers-pre_self_attention_layer_norm-scale"
330- ]
331- ),
332- 4 ,
333- )
251+ self .assertIn ("params-decoder-layers-pre_self_attention_layer_norm-scale" , mapping )
252+ self .assertIsInstance (mapping ["params-decoder-layers-pre_self_attention_layer_norm-scale" ], list )
253+ self .assertEqual (len (mapping ["params-decoder-layers-pre_self_attention_layer_norm-scale" ]), 4 )
334254
335255 def test_deepseek_hooks (self ):
336256 config = {
@@ -352,17 +272,10 @@ def test_deepseek_mapping_scanned(self):
352272 "n_routed_experts" : 2 ,
353273 }
354274 maxtext_config = mock .Mock ()
355- mapping = param_mapping .DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (
356- config , maxtext_config , scan_layers = True
357- )
275+ mapping = param_mapping .DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
358276
359- self .assertIn (
360- "params-decoder-dense_layers-self_attention-query-kernel" , mapping
361- )
362- self .assertIn (
363- "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel" ,
364- mapping ,
365- )
277+ self .assertIn ("params-decoder-dense_layers-self_attention-query-kernel" , mapping )
278+ self .assertIn ("params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel" , mapping )
366279
367280 def test_gpt_oss_hooks (self ):
368281 config = {
@@ -371,9 +284,7 @@ def test_gpt_oss_hooks(self):
371284 }
372285 maxtext_config = mock .Mock ()
373286 maxtext_config .inhomogeneous_layer_cycle_interval = 1
374- hooks = param_mapping .GPT_OSS_TO_HF_PARAM_HOOK_FN (
375- config , maxtext_config , scan_layers = False , saving_to_hf = True
376- )
287+ hooks = param_mapping .GPT_OSS_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = True )
377288
378289 self .assertIn ("params-decoder-logits_dense-kernel" , hooks )
379290
@@ -383,14 +294,9 @@ def test_gpt_oss_mapping_scanned(self):
383294 }
384295 maxtext_config = mock .Mock ()
385296 maxtext_config .inhomogeneous_layer_cycle_interval = 2
386- mapping = param_mapping .GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (
387- config , maxtext_config , scan_layers = True
388- )
297+ mapping = param_mapping .GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
389298
390- self .assertIn (
391- "params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale" ,
392- mapping ,
393- )
299+ self .assertIn ("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale" , mapping )
394300
395301 def test_mixtral_hooks (self ):
396302 config = {
@@ -399,9 +305,7 @@ def test_mixtral_hooks(self):
399305 }
400306 maxtext_config = mock .Mock ()
401307 maxtext_config .head_dim = 64
402- hooks = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN (
403- config , maxtext_config , scan_layers = False , saving_to_hf = True
404- )
308+ hooks = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = True )
405309
406310 self .assertIn ("params-decoder-logits_dense-kernel" , hooks )
407311
@@ -414,9 +318,7 @@ class Config:
414318 num_experts = 4
415319
416320 maxtext_config = Config ()
417- mapping = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING (
418- config , maxtext_config , scan_layers = True
419- )
321+ mapping = param_mapping .MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
420322
421323 self .assertIn ("params-decoder-layers-self_attention-query-kernel" , mapping )
422324
@@ -428,14 +330,9 @@ def test_gemma4_mapping_scanned(self):
428330 maxtext_config .share_kv_projections = False
429331 maxtext_config .use_multimodal = False
430332 maxtext_config .v_norm_with_scale = False
431- mapping = param_mapping .GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING (
432- config , maxtext_config , scan_layers = True
433- )
333+ mapping = param_mapping .GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = True )
434334
435- self .assertIn (
436- "params-decoder-scanned_blocks-layers_0-self_attention-query-kernel" ,
437- mapping ,
438- )
335+ self .assertIn ("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel" , mapping )
439336
440337
441338if __name__ == "__main__" :
0 commit comments