
Commit 5f81bf3

Add unit tests for param_mapping.py (mappings and specific hooks)
1 parent 60a3824 commit 5f81bf3

1 file changed

Lines changed: 238 additions & 0 deletions

File tree

tests/unit/param_mapping_test.py

@@ -0,0 +1,238 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for param_mapping.py."""

import unittest
from unittest import mock

import numpy as np

from maxtext.checkpoint_conversion.utils import param_mapping
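

# Mapping keys are MaxText's flattened parameter paths joined with "-"
# (e.g. "params-token_embedder-embedding"); per the function names below,
# the mapped values name the corresponding Hugging Face checkpoint
# parameters.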
class ParamMappingTest(unittest.TestCase):

  def test_gemma3_mapping_unscanned(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma3_mapping_scanned(self):
    config = {
        "text_config": {"num_hidden_layers": 12, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma2_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_norm_local-scale", mapping)
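
  # Unscanned mappings index each layer explicitly ("layers_0", "layers_1",
  # ...), while scanned mappings stack all layers behind a single "layers"
  # segment, as the assertion above shows.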

  def test_qwen_mapping_dense(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen_mapping_moe(self):
    config = {
        "num_hidden_layers": 2,
        "num_experts": 4,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-decoder-layers_0-moe_block-wi_0", mapping)

  def test_qwen_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "hidden_size": 256,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-pre_self_attention_layer_norm-scale", mapping)
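
  # Qwen3-Next (and GPT-OSS below) interleave different layer types; per the
  # attribute name, inhomogeneous_layer_cycle_interval is the length of the
  # repeating layer block, which is presumably why the scanned key below
  # keeps a per-block index ("layer_0").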
  def test_qwen3_next_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_qwen3_next_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "num_experts": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layer_0-input_layernorm-scale", mapping)

  def test_deepseek_mapping(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_deepseek_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
        "first_k_dense_replace": 1,
        "n_routed_experts": 2,
    }
    maxtext_config = mock.Mock()
    mapping = param_mapping.DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-dense_layers-self_attention-query-kernel", mapping)

  def test_gpt_oss_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gpt_oss_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 2
    mapping = param_mapping.GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-layers_0-pre_self_attention_layer_norm-scale", mapping)

  def test_mixtral_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.num_experts = 4
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_mixtral_mapping_scanned(self):
    config = {
        "num_hidden_layers": 4,
    }

    class Config:
      num_experts = 4

    maxtext_config = Config()
    mapping = param_mapping.MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-layers-self_attention-query-kernel", mapping)

  def test_gemma4_mapping(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
    self.assertIn("params-token_embedder-embedding", mapping)

  def test_gemma4_mapping_scanned(self):
    config = {
        "num_hidden_layers": 12,
    }
    maxtext_config = mock.Mock()
    maxtext_config.share_kv_projections = False
    maxtext_config.use_multimodal = False
    maxtext_config.v_norm_with_scale = False
    mapping = param_mapping.GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=True)
    self.assertIn("params-decoder-scanned_blocks-layers_0-self_attention-query-kernel", mapping)

  # Hook-specific tests with concrete numeric assertions.
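  # Hooks are called as hook(tensor, target_shape) and return the tensor in
  # the destination layout (here saving_to_hf=True); for the 2-D query
  # kernel below the expected result is simply the transpose.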
  def test_reshape_kernel_hook(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    hooks = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    reshape_key = "params-decoder-layers_0-self_attention-query-kernel"
    reshape_hook = hooks[reshape_key]

    dummy_tensor = np.arange(6).reshape(2, 3).astype(np.float32)
    target_shape = (3, 2)
    output = reshape_hook(dummy_tensor, target_shape)
    expected_output = dummy_tensor.T
    np.testing.assert_allclose(output, expected_output)
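
  # MaxText stores RMSNorm scales offset by +1, consistent with Gemma's
  # (1 + weight) RMSNorm parameterization, so saving to HF subtracts 1:
  # [2.0, 3.0] -> [1.0, 2.0].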
  def test_scale_rmsnorm_hook(self):
    config = {
        "text_config": {"num_hidden_layers": 2, "hidden_size": 256},
        "vision_config": {"num_hidden_layers": 1, "hidden_size": 128},
    }
    maxtext_config = mock.Mock()
    hooks_to_hf = param_mapping.GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    norm_key = "params-decoder-layers_0-pre_self_attention_norm-scale"
    norm_hook_to_hf = hooks_to_hf[norm_key]

    dummy_tensor = np.array([2.0, 3.0], dtype=np.float32)
    output = norm_hook_to_hf(dummy_tensor, (2,))
    np.testing.assert_allclose(output, np.array([1.0, 2.0]))
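
  # MaxText keeps the two GPT-OSS MLP input projections as separate
  # wi_0/wi_1 tensors, while the HF checkpoint interleaves them element-wise
  # ([1, 2] and [3, 4] -> [1, 3, 2, 4]); the hook is therefore keyed by a
  # tuple of both parameter names.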
  def test_interleave_hook(self):
    config = {
        "num_hidden_layers": 2,
    }
    maxtext_config = mock.Mock()
    maxtext_config.inhomogeneous_layer_cycle_interval = 1
    hooks_to_hf = param_mapping.GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=True)
    composite_key = ("params-decoder-layers_0-GptOssMlp-wi_0", "params-decoder-layers_0-GptOssMlp-wi_1")
    interleave_hook = hooks_to_hf[composite_key]

    wi_0 = np.array([1, 2], dtype=np.float32)
    wi_1 = np.array([3, 4], dtype=np.float32)

    output = interleave_hook((wi_0, wi_1), (4,))
    expected_output = np.array([1, 3, 2, 4], dtype=np.float32)
    np.testing.assert_allclose(output, expected_output)


if __name__ == "__main__":
  unittest.main()
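
Because the file ends in a unittest.main() guard, the suite can be run standalone (python tests/unit/param_mapping_test.py) or collected by pytest; every model config here is a plain dict or mock.Mock() stand-in, so the tests exercise only the mapping and hook construction logic and need no accelerator or real checkpoint.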
