Skip to content

Commit 94949ed

Browse files
committed
Reformat with pyink indentation=2 line-length=122
1 parent 5c832f5 commit 94949ed

1 file changed

Lines changed: 44 additions & 147 deletions

File tree

tests/unit/param_mapping_test.py

Lines changed: 44 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,10 @@ def test_gemma3_mapping_unscanned(self):
2929
"vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
3030
}
3131
maxtext_config = mock.Mock()
32-
mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
33-
config, maxtext_config, scan_layers=False
34-
)
32+
mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
3533

3634
self.assertIn("params-token_embedder-embedding", mapping)
37-
self.assertEqual(
38-
mapping["params-token_embedder-embedding"],
39-
"model.language_model.embed_tokens.weight",
40-
)
35+
self.assertEqual(mapping["params-token_embedder-embedding"], "model.language_model.embed_tokens.weight")
4136

4237
# Check text decoder layer 0
4338
self.assertIn("params-decoder-layers_0-pre_self_attention_norm-scale", mapping)
@@ -48,13 +43,10 @@ def test_gemma3_mapping_unscanned(self):
4843

4944
# Check vision encoder layer 0
5045
self.assertIn(
51-
"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale",
52-
mapping,
46+
"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale", mapping
5347
)
5448
self.assertEqual(
55-
mapping[
56-
"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale"
57-
],
49+
mapping["params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_0-LayerNorm_0-scale"],
5850
"model.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight",
5951
)
6052

@@ -64,36 +56,23 @@ def test_gemma3_mapping_scanned(self):
6456
"vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
6557
}
6658
maxtext_config = mock.Mock()
67-
mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
68-
config, maxtext_config, scan_layers=True
69-
)
59+
mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
7060

7161
self.assertIn("params-token_embedder-embedding", mapping)
7262

7363
# Check scanned block mapping
74-
self.assertIn(
75-
"params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping
76-
)
77-
self.assertIsInstance(
78-
mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"],
79-
list,
80-
)
81-
self.assertEqual(
82-
len(
83-
mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]
84-
),
85-
2,
86-
)
64+
self.assertIn("params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping)
65+
self.assertIsInstance(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"], list)
66+
# Gemma3 repeats a 6-layer pattern. 12 layers means 2 of each.
67+
self.assertEqual(len(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]), 2)
8768

8869
def test_gemma3_hooks(self):
8970
config = {
9071
"text_config": {"num_hidden_layers": 2, "hidden_size": 256},
9172
"vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
9273
}
9374
maxtext_config = mock.Mock()
94-
hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
95-
config, maxtext_config, scan_layers=False, saving_to_hf=True
96-
)
75+
hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
9776

9877
self.assertIn("params-token_embedder-embedding", hooks)
9978

@@ -112,35 +91,25 @@ def test_gemma2_mapping(self):
11291
"hidden_size": 256,
11392
}
11493
maxtext_config = mock.Mock()
115-
mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
116-
config, maxtext_config, scan_layers=False
117-
)
94+
mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
11895

11996
self.assertIn("params-token_embedder-embedding", mapping)
12097
# Gemma2 maps MaxText layer i to HF layers 2i and 2i+1
121-
self.assertIn(
122-
"params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping
123-
)
98+
self.assertIn("params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping)
12499
self.assertEqual(
125-
mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"],
126-
"model.layers.0.input_layernorm.weight",
127-
)
128-
self.assertIn(
129-
"params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping
100+
mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"], "model.layers.0.input_layernorm.weight"
130101
)
102+
self.assertIn("params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping)
131103
self.assertEqual(
132-
mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"],
133-
"model.layers.1.input_layernorm.weight",
104+
mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"], "model.layers.1.input_layernorm.weight"
134105
)
135106

136107
def test_qwen_mapping_dense(self):
137108
config = {
138109
"num_hidden_layers": 2,
139110
}
140111
maxtext_config = mock.Mock()
141-
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
142-
config, maxtext_config, scan_layers=False
143-
)
112+
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
144113

145114
self.assertIn("params-token_embedder-embedding", mapping)
146115
self.assertIn("params-decoder-layers_0-mlp-wi_0-kernel", mapping)
@@ -151,14 +120,10 @@ def test_qwen_mapping_moe(self):
151120
"num_experts": 4,
152121
}
153122
maxtext_config = mock.Mock()
154-
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
155-
config, maxtext_config, scan_layers=False
156-
)
123+
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
157124

158125
self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)
159-
self.assertIsInstance(
160-
mapping["params-decoder-layers_0-moe_block-wi_0"], list
161-
)
126+
self.assertIsInstance(mapping["params-decoder-layers_0-moe_block-wi_0"], list)
162127
self.assertEqual(len(mapping["params-decoder-layers_0-moe_block-wi_0"]), 4)
163128

164129
def test_qwen3_next_mapping(self):
@@ -168,9 +133,7 @@ def test_qwen3_next_mapping(self):
168133
}
169134
maxtext_config = mock.Mock()
170135
maxtext_config.inhomogeneous_layer_cycle_interval = 2
171-
mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(
172-
config, maxtext_config, scan_layers=False
173-
)
136+
mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
174137

175138
self.assertIn("params-token_embedder-embedding", mapping)
176139
self.assertIn("params-decoder-layers_0-input_layernorm-scale", mapping)
@@ -182,43 +145,32 @@ def test_deepseek_mapping(self):
182145
"n_routed_experts": 2,
183146
}
184147
maxtext_config = mock.Mock()
185-
mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
186-
config, maxtext_config, scan_layers=False
187-
)
148+
mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
188149

189150
self.assertIn("params-token_embedder-embedding", mapping)
190151
# Layer 0 is dense
191152
self.assertIn("params-decoder-dense_layers_0-mlp-wi_0-kernel", mapping)
192153
# Layer 1 is MoE
193-
self.assertIn(
194-
"params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
195-
mapping,
196-
)
154+
self.assertIn("params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)
197155

198156
def test_gpt_oss_mapping(self):
199157
config = {
200158
"num_hidden_layers": 2,
201159
}
202160
maxtext_config = mock.Mock()
203161
maxtext_config.inhomogeneous_layer_cycle_interval = 1
204-
mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
205-
config, maxtext_config, scan_layers=False
206-
)
162+
mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
207163

208164
self.assertIn("params-token_embedder-embedding", mapping)
209-
self.assertIn(
210-
"params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping
211-
)
165+
self.assertIn("params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping)
212166

213167
def test_mixtral_mapping(self):
214168
config = {
215169
"num_hidden_layers": 2,
216170
}
217171
maxtext_config = mock.Mock()
218172
maxtext_config.num_experts = 4
219-
mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
220-
config, maxtext_config, scan_layers=False
221-
)
173+
mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
222174

223175
self.assertIn("params-token_embedder-embedding", mapping)
224176
self.assertIn("params-decoder-layers_0-MoeBlock_0-gate-kernel", mapping)
@@ -231,9 +183,7 @@ def test_gemma4_mapping(self):
231183
maxtext_config.share_kv_projections = False
232184
maxtext_config.use_multimodal = False
233185
maxtext_config.v_norm_with_scale = False
234-
mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
235-
config, maxtext_config, scan_layers=False
236-
)
186+
mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
237187

238188
self.assertIn("params-token_embedder-embedding", mapping)
239189

@@ -244,9 +194,7 @@ def test_gemma2_hooks(self):
244194
"head_dim": 64,
245195
}
246196
maxtext_config = mock.Mock()
247-
hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
248-
config, maxtext_config, scan_layers=False, saving_to_hf=True
249-
)
197+
hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
250198

251199
self.assertIn("params-token_embedder-embedding", hooks)
252200

@@ -266,36 +214,20 @@ def test_gemma2_mapping_scanned(self):
266214
"hidden_size": 256,
267215
}
268216
maxtext_config = mock.Mock()
269-
mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
270-
config, maxtext_config, scan_layers=True
271-
)
217+
mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
272218

273219
self.assertIn("params-token_embedder-embedding", mapping)
274-
self.assertIn(
275-
"params-decoder-layers-pre_self_attention_norm_local-scale", mapping
276-
)
277-
self.assertIsInstance(
278-
mapping["params-decoder-layers-pre_self_attention_norm_local-scale"],
279-
list,
280-
)
281-
self.assertEqual(
282-
len(
283-
mapping[
284-
"params-decoder-layers-pre_self_attention_norm_local-scale"
285-
]
286-
),
287-
2,
288-
)
220+
self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)
221+
self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"], list)
222+
self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"]), 2)
289223

290224
def test_qwen_hooks(self):
291225
config = {
292226
"num_hidden_layers": 2,
293227
"hidden_size": 256,
294228
}
295229
maxtext_config = mock.Mock()
296-
hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
297-
config, maxtext_config, scan_layers=False, saving_to_hf=True
298-
)
230+
hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
299231

300232
self.assertIn("params-token_embedder-embedding", hooks)
301233

@@ -312,25 +244,11 @@ def test_qwen_mapping_scanned(self):
312244
"hidden_size": 256,
313245
}
314246
maxtext_config = mock.Mock()
315-
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
316-
config, maxtext_config, scan_layers=True
317-
)
247+
mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
318248

319-
self.assertIn(
320-
"params-decoder-layers-pre_self_attention_layer_norm-scale", mapping
321-
)
322-
self.assertIsInstance(
323-
mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"],
324-
list,
325-
)
326-
self.assertEqual(
327-
len(
328-
mapping[
329-
"params-decoder-layers-pre_self_attention_layer_norm-scale"
330-
]
331-
),
332-
4,
333-
)
249+
self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)
250+
self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"], list)
251+
self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"]), 4)
334252

335253
def test_deepseek_hooks(self):
336254
config = {
@@ -352,17 +270,10 @@ def test_deepseek_mapping_scanned(self):
352270
"n_routed_experts": 2,
353271
}
354272
maxtext_config = mock.Mock()
355-
mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
356-
config, maxtext_config, scan_layers=True
357-
)
273+
mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
358274

359-
self.assertIn(
360-
"params-decoder-dense_layers-self_attention-query-kernel", mapping
361-
)
362-
self.assertIn(
363-
"params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
364-
mapping,
365-
)
275+
self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)
276+
self.assertIn("params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)
366277

367278
def test_gpt_oss_hooks(self):
368279
config = {
@@ -371,9 +282,7 @@ def test_gpt_oss_hooks(self):
371282
}
372283
maxtext_config = mock.Mock()
373284
maxtext_config.inhomogeneous_layer_cycle_interval = 1
374-
hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(
375-
config, maxtext_config, scan_layers=False, saving_to_hf=True
376-
)
285+
hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
377286

378287
self.assertIn("params-decoder-logits_dense-kernel", hooks)
379288

@@ -383,14 +292,9 @@ def test_gpt_oss_mapping_scanned(self):
383292
}
384293
maxtext_config = mock.Mock()
385294
maxtext_config.inhomogeneous_layer_cycle_interval = 2
386-
mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
387-
config, maxtext_config, scan_layers=True
388-
)
295+
mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
389296

390-
self.assertIn(
391-
"params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale",
392-
mapping,
393-
)
297+
self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)
394298

395299
def test_mixtral_hooks(self):
396300
config = {
@@ -414,9 +318,7 @@ class Config:
414318
num_experts = 4
415319

416320
maxtext_config = Config()
417-
mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
418-
config, maxtext_config, scan_layers=True
419-
)
321+
mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
420322

421323
self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)
422324

@@ -428,14 +330,9 @@ def test_gemma4_mapping_scanned(self):
428330
maxtext_config.share_kv_projections = False
429331
maxtext_config.use_multimodal = False
430332
maxtext_config.v_norm_with_scale = False
431-
mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
432-
config, maxtext_config, scan_layers=True
433-
)
333+
mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
434334

435-
self.assertIn(
436-
"params-decoder-scanned_blocks-layers_0-self_attention-query-kernel",
437-
mapping,
438-
)
335+
self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)
439336

440337

441338
if __name__ == "__main__":

0 commit comments

Comments (0)