Commit 6655116

Add unit tests for param_mapping.py (mappings and specific hooks)
1 parent 3fc8a2b

1 file changed

tests/unit/param_mapping_test.py (241 additions, 0 deletions)
@@ -0,0 +1,241 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for param_mapping.py."""

import unittest
from unittest import mock

import numpy as np

from maxtext.checkpoint_conversion.utils import param_mapping
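
# The tests below assume each *_MAXTEXT_TO_HF_PARAM_MAPPING function returns a
# dict keyed by MaxText parameter paths (dash-separated, e.g.
# "params-token_embedder-embedding"). A minimal sketch of how a flat
# str-to-str mapping of this shape could drive a conversion (illustrative
# only; rename_params is a hypothetical helper, not part of param_mapping):
#
#   def rename_params(maxtext_params, mapping):
#     # Keep only parameters with a known HF counterpart, renamed to HF keys.
#     return {mapping[k]: v for k, v in maxtext_params.items() if k in mapping}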


class ParamMappingTest(unittest.TestCase):

  def test_gemma3_mapping_unscanned(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma3_mapping_scanned(self):
    config = {
        "text_config": {"num_hidden_layers": 12, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)
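
  # Note: with scan_layers=True the per-layer index generally folds into a
  # stacked "layers" entry (e.g. "params-decoder-layers-..." above), while
  # unscanned mappings keep per-layer keys such as "params-decoder-layers_0-...".
  # Some families below (Qwen3-Next, GPT-OSS, Gemma4) appear to keep a block
  # index inside the scanned group.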

  def test_qwen_mapping_dense(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen_mapping_moe(self):
    config = {
        "num_hidden_layers": 2,
        "num_experts": 4,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)

  def test_qwen_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)

  def test_qwen3_next_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen3_next_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layer_0-input_layernorm-scale", mapping)
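
  # inhomogeneous_layer_cycle_interval is set to a real int here (and in the
  # GPT-OSS tests below) because a Mock auto-attribute would not support the
  # integer arithmetic presumably used to group layers into cycles.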

  def test_deepseek_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_deepseek_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)

  def test_gpt_oss_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gpt_oss_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

  def test_mixtral_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.num_experts = 4
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_mixtral_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
    }

    class Config:
      num_experts = 4

    maxtext_config = Config()
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

  def test_gemma4_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma4_mapping_scanned(self):
    config = {
        "num_hidden_layers": 12,
    }
    maxtext_config = mock.Mock()
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)

  # Hook tests: check actual transformed values, not just key presence.
  def test_reshape_kernel_hook(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    reshape_key = "params-decoder-layers_0-self_attention-query-kernel"
    reshape_hook = hooks[reshape_key]

    dummy_tensor = np.arange(6).reshape(2, 3).astype(np.float32)
    target_shape = (3, 2)
    output = reshape_hook(dummy_tensor, target_shape)
    expected_output = dummy_tensor.T
    np.testing.assert_allclose(output, expected_output)
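
  # Worked example for the assertion above: arange(6).reshape(2, 3) is
  # [[0, 1, 2], [3, 4, 5]], and the hook maps it to the (3, 2) transpose
  # [[0, 3], [1, 4], [2, 5]]. This matches the usual kernel convention flip:
  # Flax kernels are (in_features, out_features) while PyTorch Linear
  # weights are (out_features, in_features).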

  def test_scale_rmsnorm_hook(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
        config, maxtext_config, scan_layers=False, saving_to_hf=True
    )
    norm_key = "params-decoder-layers_0-pre_self_attention_norm-scale"
    norm_hook_to_hf = hooks_to_hf[norm_key]

    dummy_tensor = np.array([2.0, 3.0], dtype=np.float32)
    output = norm_hook_to_hf(dummy_tensor, (2,))
    np.testing.assert_allclose(output, np.array([1.0, 2.0]))
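
  # The hook subtracts 1 from the scale ([2, 3] -> [1, 2]), consistent with
  # HF Gemma's RMSNorm computing x * (1 + weight) while MaxText stores the
  # full scale; saving to HF therefore shifts the parameter by -1.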

  def test_interleave_hook(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    hooks_to_hf = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    composite_key = ("params-decoder-layers_0-GptOssMlp-wi_0", "params-decoder-layers_0-GptOssMlp-wi_1")
    interleave_hook = hooks_to_hf[composite_key]

    wi_0 = np.array([1, 2], dtype=np.float32)
    wi_1 = np.array([3, 4], dtype=np.float32)

    output = interleave_hook((wi_0, wi_1), (4,))
    expected_output = np.array([1, 3, 2, 4], dtype=np.float32)
    np.testing.assert_allclose(output, expected_output)
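
  # The interleave hook zips wi_0 and wi_1 elementwise ([1, 2] and [3, 4]
  # become [1, 3, 2, 4]), matching checkpoints that store the gate and up
  # projections fused into one tensor with alternating entries, as the HF
  # gpt-oss gate_up_proj layout appears to do.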


if __name__ == "__main__":
  unittest.main()
