Skip to content

Commit ff17127

Browse files
committed
Add unit tests for param_mapping.py
1 parent 60a3824 commit ff17127

1 file changed

Lines changed: 327 additions & 0 deletions

File tree

tests/unit/param_mapping_test.py

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for param_mapping.py"""
16+
17+
import unittest
18+
from unittest import mock
19+
import numpy as np
20+
21+
from maxtext.checkpoint_conversion.utils import param_mapping
22+
23+
24+
class ParamMappingTest(unittest.TestCase):
  """Smoke tests for the MaxText <-> HuggingFace parameter mapping helpers.

  Each `*_mapping*` test builds a mapping dictionary and checks a known key
  is present; each `*_hooks` test builds the hook dictionaries in both
  directions and executes every hook once with dummy data.
  """

  def _execute_hooks(self, hooks):
    """Executes every hook in `hooks` once with dummy numpy data.

    Args:
      hooks: dict mapping a parameter key (str, or tuple of str for fused
        parameters) to a hook callable or a list of hook callables.

    Hook-level exceptions are deliberately swallowed: the generic dummy
    shapes used here may not match what a specific hook expects, and this
    helper only verifies that the hook dictionaries are well-formed and
    that every hook is callable.
    """
    for key, hook_val in hooks.items():
      # A mapping value may be a single hook or a list of hooks; normalize.
      hook_list = hook_val if isinstance(hook_val, list) else [hook_val]
      # Tuple keys indicate fused parameters: two source arrays combined
      # into one target tensor, so the target's last dim is doubled.
      if isinstance(key, tuple):
        dummy_data = (np.ones((10, 20), dtype=np.float32), np.ones((10, 20), dtype=np.float32))
        target_shape = (10, 40)
      else:
        dummy_data = np.ones((10, 20), dtype=np.float32)
        target_shape = (10, 20)
      for hook in hook_list:
        # Keep the try body minimal: only the hook call itself is
        # best-effort; setup errors above should surface as test failures.
        try:
          _ = hook(dummy_data, target_shape)
        except Exception:  # pylint: disable=broad-exception-caught
          pass

  def test_gemma3_mapping_unscanned(self):
    """Gemma3 unscanned mapping contains the token-embedding entry."""
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

    self.assertIn("params-token_embedder-embedding", mapping)
    self.assertEqual(mapping["params-token_embedder-embedding"], "model.language_model.embed_tokens.weight")

  def test_gemma3_mapping_scanned(self):
    """Gemma3 scanned mapping contains the token-embedding entry."""
    config = {
        "text_config": {"num_hidden_layers": 12, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma3_hooks(self):
    """Gemma3 hook dictionaries execute in both conversion directions."""
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_gemma2_mapping(self):
    """Gemma2 unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping_scanned(self):
    """Gemma2 scanned mapping contains the scanned-layer norm entry."""
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)

  def test_gemma2_hooks(self):
    """Gemma2 hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
        "head_dim": 64,
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_qwen_mapping_dense(self):
    """Qwen dense (no experts) mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen_mapping_moe(self):
    """Qwen MoE mapping (num_experts set) contains a MoE-block entry."""
    config = {
        "num_hidden_layers": 2,
        "num_experts": 4,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)

  def test_qwen_mapping_scanned(self):
    """Qwen scanned mapping contains the scanned-layer norm entry."""
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)

  def test_qwen_hooks(self):
    """Qwen hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 2,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_qwen3_next_mapping(self):
    """Qwen3-Next unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    # Mapping layout depends on the hybrid-layer cycle length.
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen3_next_mapping_scanned(self):
    """Qwen3-Next scanned mapping contains the per-cycle-layer norm entry."""
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layer_0-input_layernorm-scale", mapping)

  def test_qwen3_next_hooks(self):
    """Qwen3-Next hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    hooks_to_hf = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_deepseek_mapping(self):
    """DeepSeek unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_deepseek_mapping_scanned(self):
    """DeepSeek scanned mapping contains the dense-layer attention entry."""
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)

  def test_deepseek_hooks(self):
    """DeepSeek hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_gpt_oss_mapping(self):
    """GPT-OSS unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gpt_oss_mapping_scanned(self):
    """GPT-OSS scanned mapping contains the per-cycle-layer norm entry."""
    config = {
        "num_hidden_layers": 4,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

  def test_gpt_oss_hooks(self):
    """GPT-OSS hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 2,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    hooks_to_hf = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False)
    self._execute_hooks(hooks_to_mt)

  def test_mixtral_mapping(self):
    """Mixtral unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.num_experts = 4
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_mixtral_mapping_scanned(self):
    """Mixtral scanned mapping contains the scanned attention entry."""
    config = {
        "num_hidden_layers": 4,
    }

    # A plain object (not a Mock) so only the attribute the mapping actually
    # reads is available; any other attribute access fails loudly.
    class Config:
      num_experts = 4

    maxtext_config = Config()
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

    self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

  def test_mixtral_hooks(self):
    """Mixtral hook dictionaries execute in both conversion directions."""
    config = {
        "num_hidden_layers": 2,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    maxtext_config.head_dim = 64
    hooks_to_hf = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    self._execute_hooks(hooks_to_hf)

    hooks_to_mt = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=False
    )
    self._execute_hooks(hooks_to_mt)

  def test_gemma4_mapping(self):
    """Gemma4 unscanned mapping contains the token-embedding entry."""
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    # Explicit flags: the mapping branches on these boolean attributes.
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)

    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma4_mapping_scanned(self):
    """Gemma4 scanned mapping contains the scanned-block attention entry."""
    config = {
        "num_hidden_layers": 12,
    }
    maxtext_config = mock.Mock()
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)

    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)
# Allow running this test module directly: `python param_mapping_test.py`.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)