Skip to content

Commit 6e8d1d2

Browse files
committed
Add unit tests for param_mapping.py
1 parent 1565cf7 commit 6e8d1d2

1 file changed

Lines changed: 328 additions & 0 deletions

File tree

tests/unit/param_mapping_test.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
# Copyright 2023-2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for param_mapping.py"""
16+
17+
import unittest
18+
from unittest import mock
19+
import numpy as np
20+
21+
from maxtext.checkpoint_conversion.utils import param_mapping
22+
23+
24+
class ParamMappingTest(unittest.TestCase):
  """Smoke tests for the MaxText <-> HuggingFace parameter mappings and hooks."""

  def _execute_hooks(self, hooks, saving_to_hf):
    """Best-effort smoke run of every transformation hook with dummy arrays.

    Args:
      hooks: mapping of parameter key -> hook callable (or list of callables).
        A tuple key marks a hook that consumes a pair of arrays and produces a
        fused target shape.
      saving_to_hf: conversion direction; accepted for call-site symmetry but
        not consulted here, since the dummy inputs are direction-agnostic.
    """
    del saving_to_hf  # unused: dummy inputs do not depend on the direction
    for key, value in hooks.items():
      hook_fns = value if isinstance(value, list) else [value]
      paired = isinstance(key, tuple)
      for hook in hook_fns:
        # Build fresh inputs for every hook invocation.
        if paired:
          data = (np.ones((10, 20), dtype=np.float32), np.ones((10, 20), dtype=np.float32))
          shape = (10, 40)
        else:
          data = np.ones((10, 20), dtype=np.float32)
          shape = (10, 20)
        try:
          hook(data, shape)
        except Exception:  # pylint: disable=broad-exception-caught
          # NOTE(review): best-effort sweep — a raising hook is silently
          # skipped, so this only verifies that hooks are constructible.
          pass

  def test_gemma3_mapping_unscanned(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)
    self.assertEqual(mapping["params-token_embedder-embedding"], "model.language_model.embed_tokens.weight")

  def test_gemma3_mapping_scanned(self):
    config = {
        "text_config": {"num_hidden_layers": 12, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=True)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma3_hooks(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    mt_config = mock.Mock()
    # Exercise both conversion directions.
    for to_hf in (True, False):
      hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_gemma2_mapping(self):
    config = {"num_hidden_layers": 4, "hidden_size": 256}
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping_scanned(self):
    config = {"num_hidden_layers": 4, "hidden_size": 256}
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)

  def test_gemma2_hooks(self):
    config = {"num_hidden_layers": 4, "hidden_size": 256, "head_dim": 64}
    mt_config = mock.Mock()
    for to_hf in (True, False):
      hooks = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_qwen_mapping_dense(self):
    config = {"num_hidden_layers": 2}
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen_mapping_moe(self):
    # Presence of "num_experts" should switch the mapping to the MoE layout.
    config = {"num_hidden_layers": 2, "num_experts": 4}
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=False)
    self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)

  def test_qwen_mapping_scanned(self):
    config = {"num_hidden_layers": 4, "hidden_size": 256}
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)

  def test_qwen_hooks(self):
    config = {"num_hidden_layers": 2, "hidden_size": 256}
    mt_config = mock.Mock()
    for to_hf in (True, False):
      hooks = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_qwen3_next_mapping(self):
    config = {"num_hidden_layers": 4, "num_experts": 2}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen3_next_mapping_scanned(self):
    config = {"num_hidden_layers": 4, "num_experts": 2}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layer_0-input_layernorm-scale", mapping)

  def test_qwen3_next_hooks(self):
    config = {"num_hidden_layers": 4, "num_experts": 2, "hidden_size": 256}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 2
    for to_hf in (True, False):
      hooks = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_deepseek_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_deepseek_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, mock.Mock(), scan_layers=True)
    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)

  def test_deepseek_hooks(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "hidden_size": 256,
    }
    mt_config = mock.Mock()
    for to_hf in (True, False):
      hooks = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_gpt_oss_mapping(self):
    config = {"num_hidden_layers": 2}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 1
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gpt_oss_mapping_scanned(self):
    config = {"num_hidden_layers": 4}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

  def test_gpt_oss_hooks(self):
    config = {"num_hidden_layers": 2, "hidden_size": 256}
    mt_config = mock.Mock()
    mt_config.inhomogeneous_layer_cycle_interval = 1
    for to_hf in (True, False):
      hooks = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, mt_config, scan_layers=False, saving_to_hf=to_hf)
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_mixtral_mapping(self):
    config = {"num_hidden_layers": 2}
    mt_config = mock.Mock()
    mt_config.num_experts = 4
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_mixtral_mapping_scanned(self):
    config = {"num_hidden_layers": 4}

    # Plain object (not a Mock) so only the declared attribute exists.
    class FakeConfig:
      num_experts = 4

    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, FakeConfig(), scan_layers=True)
    self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

  def test_mixtral_hooks(self):
    config = {"num_hidden_layers": 2, "hidden_size": 256}
    mt_config = mock.Mock()
    mt_config.head_dim = 64
    for to_hf in (True, False):
      hooks = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config, mt_config, scan_layers=False, saving_to_hf=to_hf
      )
      self._execute_hooks(hooks, saving_to_hf=to_hf)

  def test_gemma4_mapping(self):
    config = {"num_hidden_layers": 2}
    mt_config = mock.Mock()
    mt_config.share_kv_projections = False
    mt_config.use_multimodal = False
    mt_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma4_mapping_scanned(self):
    config = {"num_hidden_layers": 12}
    mt_config = mock.Mock()
    mt_config.share_kv_projections = False
    mt_config.use_multimodal = False
    mt_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, mt_config, scan_layers=True)
    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)
# Allow running this file directly: `python param_mapping_test.py`.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)