Skip to content

Commit ed0d5e6

Browse files
Merge pull request #3693 from AI-Hypercomputer:add-linears-tests
PiperOrigin-RevId: 902803301
2 parents 3de696b + 4a48e26 commit ed0d5e6

1 file changed

Lines changed: 204 additions & 0 deletions

File tree

tests/unit/linears_test.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for linears.py."""
16+
17+
import sys
18+
import unittest
19+
from flax import nnx
20+
import jax
21+
import jax.numpy as jnp
22+
import numpy as np
23+
24+
from maxtext.layers import linears
25+
from maxtext.configs import pyconfig
26+
from maxtext.utils import maxtext_utils
27+
from tests.utils.test_helpers import get_test_config_path
28+
29+
30+
class UtilsTest(unittest.TestCase):
31+
"""Tests for utility functions in linears.py."""
32+
33+
def test_normalize_axes(self):
34+
self.assertEqual(linears.normalize_axes((1, 2), 4), (1, 2))
35+
self.assertEqual(linears.normalize_axes((-1, -2), 4), (3, 2))
36+
self.assertEqual(linears.normalize_axes((0, -1), 3), (0, 2))
37+
38+
def test_canonicalize_tuple(self):
39+
self.assertEqual(linears.canonicalize_tuple(1), (1,))
40+
self.assertEqual(linears.canonicalize_tuple((1, 2)), (1, 2))
41+
self.assertEqual(linears.canonicalize_tuple([1, 2]), (1, 2))
42+
43+
# pylint: disable=protected-access
44+
def test_convert_to_activation_function(self):
45+
lin_fn = linears._convert_to_activation_function("linear")
46+
x = jnp.array([1.0, 2.0])
47+
np.testing.assert_array_equal(lin_fn(x), x)
48+
49+
relu_fn = linears._convert_to_activation_function("relu")
50+
x = jnp.array([-1.0, 2.0])
51+
np.testing.assert_array_equal(relu_fn(x), jnp.array([0.0, 2.0]))
52+
53+
# Test with callable
54+
def dummy_fn(x):
55+
return x + 1
56+
57+
self.assertEqual(linears._convert_to_activation_function(dummy_fn), dummy_fn)
58+
59+
with self.assertRaises(ValueError):
60+
linears._convert_to_activation_function(123)
61+
62+
63+
class DenseGeneralTest(unittest.TestCase):
64+
"""Tests for DenseGeneral."""
65+
66+
def setUp(self):
67+
super().setUp()
68+
self.rngs = nnx.Rngs(params=0)
69+
70+
def test_basic_call(self):
71+
batch_size = 2
72+
in_features = 4
73+
out_features = 8
74+
75+
layer = linears.DenseGeneral(
76+
in_features_shape=in_features,
77+
out_features_shape=out_features,
78+
rngs=self.rngs,
79+
)
80+
81+
inputs = jnp.ones((batch_size, in_features))
82+
outputs = layer(inputs)
83+
84+
self.assertEqual(outputs.shape, (batch_size, out_features))
85+
86+
def test_bias(self):
87+
batch_size = 2
88+
in_features = 4
89+
out_features = 8
90+
91+
layer = linears.DenseGeneral(
92+
in_features_shape=in_features,
93+
out_features_shape=out_features,
94+
use_bias=True,
95+
rngs=self.rngs,
96+
)
97+
98+
inputs = jnp.ones((batch_size, in_features))
99+
outputs = layer(inputs)
100+
101+
self.assertEqual(outputs.shape, (batch_size, out_features))
102+
self.assertIsNotNone(layer.bias)
103+
104+
def _run_dense_test(self, axis, in_feat_shape, expected_shape):
105+
batch_size = 2
106+
seq_len = 3
107+
in_features = 4
108+
out_features = 8
109+
110+
layer = linears.DenseGeneral(
111+
in_features_shape=in_feat_shape,
112+
out_features_shape=out_features,
113+
axis=axis,
114+
rngs=self.rngs,
115+
)
116+
117+
inputs = jnp.ones((batch_size, seq_len, in_features))
118+
outputs = layer(inputs)
119+
120+
self.assertEqual(outputs.shape, expected_shape)
121+
122+
def test_axis_neg_1(self):
123+
self._run_dense_test(-1, 4, (2, 3, 8))
124+
125+
def test_axis_1(self):
126+
self._run_dense_test(1, 3, (2, 4, 8))
127+
128+
def test_axis_0(self):
129+
self._run_dense_test(0, 2, (3, 4, 8))
130+
131+
132+
class MlpBlockTest(unittest.TestCase):
133+
"""Tests for MlpBlock."""
134+
135+
def setUp(self):
136+
super().setUp()
137+
self.rngs = nnx.Rngs(params=0, dropout=1)
138+
139+
config_arguments = {
140+
"per_device_batch_size": 1.0,
141+
"run_name": "test",
142+
"enable_checkpointing": False,
143+
"max_target_length": 128,
144+
"fused_mlp": False,
145+
}
146+
argv = [sys.argv[0], get_test_config_path()]
147+
self.cfg = pyconfig.initialize(argv, **config_arguments)
148+
149+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
150+
self.mesh = jax.sharding.Mesh(devices_array, self.cfg.mesh_axes)
151+
152+
def test_basic_call(self):
153+
batch_size = 2
154+
seq_len = 3
155+
in_features = 4
156+
intermediate_dim = 8
157+
158+
layer = linears.MlpBlock(
159+
config=self.cfg,
160+
mesh=self.mesh,
161+
in_features=in_features,
162+
intermediate_dim=intermediate_dim,
163+
rngs=self.rngs,
164+
)
165+
166+
inputs = jnp.ones((batch_size, seq_len, in_features))
167+
outputs = layer(inputs)
168+
169+
self.assertEqual(outputs.shape, (batch_size, seq_len, in_features))
170+
self.assertEqual(layer.wi.kernel[...].shape, (in_features, intermediate_dim))
171+
172+
def test_fused_mlp(self):
173+
batch_size = 2
174+
seq_len = 3
175+
in_features = 4
176+
intermediate_dim = 8
177+
178+
config_arguments = {
179+
"per_device_batch_size": 1.0,
180+
"run_name": "test",
181+
"enable_checkpointing": False,
182+
"max_target_length": 128,
183+
"fused_mlp": True,
184+
}
185+
argv = [sys.argv[0], get_test_config_path()]
186+
cfg_fused = pyconfig.initialize(argv, **config_arguments)
187+
188+
layer = linears.MlpBlock(
189+
config=cfg_fused,
190+
mesh=self.mesh,
191+
in_features=in_features,
192+
intermediate_dim=intermediate_dim,
193+
rngs=self.rngs,
194+
)
195+
196+
inputs = jnp.ones((batch_size, seq_len, in_features))
197+
outputs = layer(inputs)
198+
199+
self.assertEqual(outputs.shape, (batch_size, seq_len, in_features))
200+
self.assertEqual(layer.wi.kernel[...].shape, (in_features, 1, intermediate_dim))
201+
202+
203+
if __name__ == "__main__":
204+
unittest.main()

0 commit comments

Comments
 (0)