Skip to content

Commit 59cf305

Browse files
Merge pull request #3706 from AI-Hypercomputer:add-embeddings-tests
PiperOrigin-RevId: 902870201
2 parents 5182e3b + e103e63 commit 59cf305

2 files changed

Lines changed: 236 additions & 0 deletions

File tree

tests/unit/embeddings_test.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for embeddings.py."""
16+
17+
import sys
18+
import unittest
19+
from flax import nnx
20+
import jax
21+
import jax.numpy as jnp
22+
import numpy as np
23+
24+
from maxtext.layers import embeddings
25+
from maxtext.configs import pyconfig
26+
from maxtext.utils import maxtext_utils
27+
from tests.utils.test_helpers import get_test_config_path
28+
29+
30+
class EmbedTest(unittest.TestCase):
31+
"""Tests for Embed."""
32+
33+
def setUp(self):
34+
super().setUp()
35+
self.rngs = nnx.Rngs(params=0)
36+
37+
config_arguments = {
38+
"per_device_batch_size": 1.0,
39+
"run_name": "test",
40+
"enable_checkpointing": False,
41+
"max_target_length": 128,
42+
}
43+
argv = [sys.argv[0], get_test_config_path()]
44+
self.cfg = pyconfig.initialize(argv, **config_arguments)
45+
46+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
47+
self.mesh = jax.sharding.Mesh(devices_array, self.cfg.mesh_axes)
48+
49+
def test_basic_call(self):
50+
num_embeddings = 100
51+
num_features = 16
52+
batch_size = 2
53+
seq_len = 3
54+
55+
layer = embeddings.Embed(
56+
num_embeddings=num_embeddings,
57+
num_features=num_features,
58+
config=self.cfg,
59+
mesh=self.mesh,
60+
rngs=self.rngs,
61+
)
62+
63+
inputs = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
64+
outputs = layer(inputs)
65+
66+
self.assertEqual(outputs.shape, (batch_size, seq_len, num_features))
67+
68+
def test_attend(self):
69+
num_embeddings = 100
70+
num_features = 16
71+
batch_size = 2
72+
seq_len = 3
73+
74+
layer = embeddings.Embed(
75+
num_embeddings=num_embeddings,
76+
num_features=num_features,
77+
config=self.cfg,
78+
mesh=self.mesh,
79+
rngs=self.rngs,
80+
)
81+
82+
query = jnp.ones((batch_size, seq_len, num_features))
83+
outputs = layer.attend(query)
84+
85+
self.assertEqual(outputs.shape, (batch_size, seq_len, num_embeddings))
86+
87+
88+
class RotaryEmbeddingTest(unittest.TestCase):
89+
"""Tests for RotaryEmbedding."""
90+
91+
def setUp(self):
92+
super().setUp()
93+
self.rngs = nnx.Rngs(params=0)
94+
95+
config_arguments = {
96+
"per_device_batch_size": 1.0,
97+
"run_name": "test",
98+
"enable_checkpointing": False,
99+
"max_target_length": 128,
100+
}
101+
argv = [sys.argv[0], get_test_config_path()]
102+
self.cfg = pyconfig.initialize(argv, **config_arguments)
103+
104+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
105+
self.mesh = jax.sharding.Mesh(devices_array, self.cfg.mesh_axes)
106+
107+
def test_basic_call(self):
108+
layer = embeddings.RotaryEmbedding(
109+
min_timescale=1,
110+
max_timescale=10000,
111+
mesh=self.mesh,
112+
embedding_dims=4,
113+
rngs=self.rngs,
114+
)
115+
116+
inputs = jnp.ones((1, 2, 1, 4))
117+
position = jnp.array([[0, 1]])
118+
119+
outputs = layer(inputs, position=position)
120+
121+
self.assertEqual(outputs.shape, (1, 2, 1, 4))
122+
123+
# Snapshot verification
124+
expected = jnp.array([[[[1.0, 1.0, 1.0, 1.0]], [[-0.300781, 0.988281, 1.38281, 1.00781]]]])
125+
np.testing.assert_allclose(outputs, expected, atol=1e-5)
126+
127+
128+
class LLaMARotaryEmbeddingTest(unittest.TestCase):
129+
130+
def setUp(self):
131+
super().setUp()
132+
self.rngs = nnx.Rngs(params=0)
133+
134+
config_arguments = {
135+
"per_device_batch_size": 1.0,
136+
"run_name": "test",
137+
"enable_checkpointing": False,
138+
"max_target_length": 128,
139+
}
140+
argv = [sys.argv[0], get_test_config_path()]
141+
self.cfg = pyconfig.initialize(argv, **config_arguments)
142+
143+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
144+
self.mesh = jax.sharding.Mesh(devices_array, self.cfg.mesh_axes)
145+
146+
def test_basic_call(self):
147+
layer = embeddings.LLaMARotaryEmbedding(
148+
min_timescale=1,
149+
max_timescale=10000,
150+
mesh=self.mesh,
151+
embedding_dims=4,
152+
use_scale=True,
153+
rngs=self.rngs,
154+
)
155+
inputs = jnp.ones((1, 2, 1, 4))
156+
position = jnp.array([[0, 1]])
157+
outputs = layer(inputs, position=position)
158+
self.assertEqual(outputs.shape, (1, 2, 1, 4))
159+
160+
# Snapshot verification
161+
expected = jnp.array([[[[1.0, 1.0, 1.0, 1.0]], [[-0.300781, 1.38281, 0.988281, 1.00781]]]])
162+
np.testing.assert_allclose(outputs, expected, atol=1e-5)
163+
164+
165+
class YarnRotaryEmbeddingTest(unittest.TestCase):
166+
167+
def setUp(self):
168+
super().setUp()
169+
self.rngs = nnx.Rngs(params=0)
170+
171+
config_arguments = {
172+
"per_device_batch_size": 1.0,
173+
"run_name": "test",
174+
"enable_checkpointing": False,
175+
"max_target_length": 128,
176+
}
177+
argv = [sys.argv[0], get_test_config_path()]
178+
self.cfg = pyconfig.initialize(argv, **config_arguments)
179+
180+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
181+
self.mesh = jax.sharding.Mesh(devices_array, self.cfg.mesh_axes)
182+
183+
def test_basic_call(self):
184+
layer = embeddings.YarnRotaryEmbedding(
185+
embedding_dims=4,
186+
mesh=self.mesh,
187+
max_position_embeddings=16384,
188+
original_max_position_embeddings=4096,
189+
rngs=self.rngs,
190+
)
191+
inputs = jnp.ones((1, 2, 1, 4))
192+
position = jnp.array([[0, 1]])
193+
outputs = layer(inputs, position=position)
194+
self.assertEqual(outputs.shape, (1, 2, 1, 4))
195+
196+
# Snapshot verification
197+
expected = jnp.array([[[[1.0, 1.0, 1.0, 1.0]], [[-0.300781, 0.996094, 1.38281, 1.00781]]]])
198+
np.testing.assert_allclose(outputs, expected, atol=1e-5)
199+
200+
201+
if __name__ == "__main__":
202+
unittest.main()

tests/unit/partial_rotary_embedding_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,23 @@ def get_attn(pos):
163163
err_msg="PartialRotaryEmbedding attention should be shift-invariant.",
164164
)
165165

166+
def test_snapshot_verification(self):
167+
"""Verify output values against captured snapshot."""
168+
layer = PartialRotaryEmbedding(
169+
min_timescale=1,
170+
max_timescale=10000,
171+
mesh=self.mesh,
172+
embedding_dims=4,
173+
partial_rotary_factor=0.5,
174+
rngs=self.nnx_rng,
175+
)
176+
inputs = jnp.ones((1, 2, 1, 4))
177+
position = jnp.array([[0, 1]])
178+
outputs = layer(inputs, position=position)
179+
180+
expected = jnp.array([[[[1.0, 1.0, 1.0, 1.0]], [[-0.30078125, 1.3828125, 1.0, 1.0]]]])
181+
np.testing.assert_allclose(outputs, expected, atol=1e-5)
182+
166183

167184
class Gemma4PartialRotaryEmbeddingTest(unittest.TestCase):
168185
"""Tests for the Gemma4PartialRotaryEmbedding layer."""
@@ -278,6 +295,23 @@ def get_attn(pos):
278295
err_msg="Gemma4PartialRotaryEmbedding attention should be shift-invariant.",
279296
)
280297

298+
def test_snapshot_verification(self):
299+
"""Verify output values against captured snapshot."""
300+
layer = Gemma4PartialRotaryEmbedding(
301+
min_timescale=1,
302+
max_timescale=10000,
303+
mesh=self.mesh,
304+
embedding_dims=4,
305+
partial_rotary_factor=0.5,
306+
rngs=self.nnx_rng,
307+
)
308+
inputs = jnp.ones((1, 2, 1, 4))
309+
position = jnp.array([[0, 1]])
310+
outputs = layer(inputs, position=position)
311+
312+
expected = jnp.array([[[[1.0, 1.0, 1.0, 1.0]], [[-0.300781, 1.0, 1.38281, 1.0]]]])
313+
np.testing.assert_allclose(outputs, expected, atol=1e-5)
314+
281315

282316
if __name__ == "__main__":
283317
unittest.main()

0 commit comments

Comments
 (0)