
Commit 1f1865f

Reformat with pyink indentation=2 line-length=122
1 parent 5c832f5 commit 1f1865f
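
For reference, a mechanical reformat like this is normally produced by running pyink (Google's fork of Black) over the file rather than by hand. A minimal sketch of the corresponding configuration, assuming the standard pyink settings named in the commit message (pyink-indentation and line-length); the pyproject.toml section shown here is an assumption, since the repository may instead pass these as command-line flags:

# pyproject.toml (sketch)
[tool.pyink]
line-length = 122
pyink-indentation = 2

With that in place, re-running pyink on tests/unit/param_mapping_test.py joins any call that now fits within 122 columns onto a single line, which is what the diff below shows.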

1 file changed: tests/unit/param_mapping_test.py

44 additions & 147 deletions
@@ -29,15 +29,10 @@ def test_gemma3_mapping_unscanned(self):
         "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertEqual(
-        mapping["params-token_embedder-embedding"],
-        "model.language_model.embed_tokens.weight",
-    )
+    self.assertEqual(mapping["params-token_embedder-embedding"], "model.language_model.embed_tokens.weight")

     # Check text decoder layer 0
     self.assertIn("params-decoder-layers_0-pre_self_attention_norm-scale", mapping)
@@ -64,26 +56,15 @@ def test_gemma3_mapping_scanned(self):
         "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-token_embedder-embedding", mapping)

     # Check scanned block mapping
-    self.assertIn(
-        "params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]
-        ),
-        2,
-    )
+    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_norm-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"], list)
+    # Gemma3 repeats a 6-layer pattern. 12 layers means 2 of each.
+    self.assertEqual(len(mapping["params-decoder-layers-layers_0-pre_self_attention_norm-scale"]), 2)

   def test_gemma3_hooks(self):
     config = {
@@ -112,35 +93,25 @@ def test_gemma2_mapping(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     # Gemma2 maps MaxText layer i to HF layers 2i and 2i+1
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping
-    )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_norm_local-scale", mapping)
     self.assertEqual(
-        mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"],
-        "model.layers.0.input_layernorm.weight",
-    )
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping
+        mapping["params-decoder-layers_0-pre_self_attention_norm_local-scale"], "model.layers.0.input_layernorm.weight"
     )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_norm_global-scale", mapping)
     self.assertEqual(
-        mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"],
-        "model.layers.1.input_layernorm.weight",
+        mapping["params-decoder-layers_0-pre_self_attention_norm_global-scale"], "model.layers.1.input_layernorm.weight"
     )

   def test_qwen_mapping_dense(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-mlp-wi_0-kernel", mapping)
@@ -151,14 +122,10 @@ def test_qwen_mapping_moe(self):
         "num_experts": 4,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)
-    self.assertIsInstance(
-        mapping["params-decoder-layers_0-moe_block-wi_0"], list
-    )
+    self.assertIsInstance(mapping["params-decoder-layers_0-moe_block-wi_0"], list)
     self.assertEqual(len(mapping["params-decoder-layers_0-moe_block-wi_0"]), 4)

   def test_qwen3_next_mapping(self):
@@ -168,9 +135,7 @@ def test_qwen3_next_mapping(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 2
-    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-input_layernorm-scale", mapping)
@@ -182,43 +147,32 @@ def test_deepseek_mapping(self):
         "n_routed_experts": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     # Layer 0 is dense
     self.assertIn("params-decoder-dense_layers_0-mlp-wi_0-kernel", mapping)
     # Layer 1 is MoE
-    self.assertIn(
-        "params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-moe_layers_0-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)

   def test_gpt_oss_mapping(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 1
-    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertIn(
-        "params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping
-    )
+    self.assertIn("params-decoder-layers_0-pre_self_attention_layer_norm-scale", mapping)

   def test_mixtral_mapping(self):
     config = {
         "num_hidden_layers": 2,
     }
     maxtext_config = mock.Mock()
     maxtext_config.num_experts = 4
-    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)
     self.assertIn("params-decoder-layers_0-MoeBlock_0-gate-kernel", mapping)
@@ -231,9 +185,7 @@ def test_gemma4_mapping(self):
     maxtext_config.share_kv_projections = False
     maxtext_config.use_multimodal = False
     maxtext_config.v_norm_with_scale = False
-    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=False
-    )
+    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

     self.assertIn("params-token_embedder-embedding", mapping)

@@ -244,9 +196,7 @@ def test_gemma2_hooks(self):
         "head_dim": 64,
     }
     maxtext_config = mock.Mock()
-    hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-token_embedder-embedding", hooks)

@@ -266,36 +216,20 @@ def test_gemma2_mapping_scanned(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-token_embedder-embedding", mapping)
-    self.assertIn(
-        "params-decoder-layers-pre_self_attention_norm_local-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-pre_self_attention_norm_local-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping[
-                "params-decoder-layers-pre_self_attention_norm_local-scale"
-            ]
-        ),
-        2,
-    )
+    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"], list)
+    self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_norm_local-scale"]), 2)

   def test_qwen_hooks(self):
     config = {
         "num_hidden_layers": 2,
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-token_embedder-embedding", hooks)

@@ -312,25 +246,11 @@ def test_qwen_mapping_scanned(self):
         "hidden_size": 256,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-layers-pre_self_attention_layer_norm-scale", mapping
-    )
-    self.assertIsInstance(
-        mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"],
-        list,
-    )
-    self.assertEqual(
-        len(
-            mapping[
-                "params-decoder-layers-pre_self_attention_layer_norm-scale"
-            ]
-        ),
-        4,
-    )
+    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)
+    self.assertIsInstance(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"], list)
+    self.assertEqual(len(mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"]), 4)

   def test_deepseek_hooks(self):
     config = {
@@ -352,17 +272,10 @@ def test_deepseek_mapping_scanned(self):
         "n_routed_experts": 2,
     }
     maxtext_config = mock.Mock()
-    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-dense_layers-self_attention-query-kernel", mapping
-    )
-    self.assertIn(
-        "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)
+    self.assertIn("params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", mapping)

   def test_gpt_oss_hooks(self):
     config = {
@@ -371,9 +284,7 @@ def test_gpt_oss_hooks(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 1
-    hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-decoder-logits_dense-kernel", hooks)

@@ -383,14 +294,9 @@ def test_gpt_oss_mapping_scanned(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.inhomogeneous_layer_cycle_interval = 2
-    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale",
-        mapping,
-    )
+    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

   def test_mixtral_hooks(self):
     config = {
@@ -399,9 +305,7 @@ def test_mixtral_hooks(self):
     }
     maxtext_config = mock.Mock()
     maxtext_config.head_dim = 64
-    hooks = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(
-        config, maxtext_config, scan_layers=False, saving_to_hf=True
-    )
+    hooks = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)

     self.assertIn("params-decoder-logits_dense-kernel", hooks)

@@ -414,9 +318,7 @@ class Config:
       num_experts = 4

     maxtext_config = Config()
-    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

     self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

@@ -428,14 +330,9 @@ def test_gemma4_mapping_scanned(self):
     maxtext_config.share_kv_projections = False
     maxtext_config.use_multimodal = False
     maxtext_config.v_norm_with_scale = False
-    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(
-        config, maxtext_config, scan_layers=True
-    )
+    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

-    self.assertIn(
-        "params-decoder-scanned_blocks-layers_0-self_attention-query-kernel",
-        mapping,
-    )
+    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)


 if __name__ == "__main__":
