Skip to content

Commit b5cf469

Browse files
author
maxtext authors
committed
Merge pull request #1671 from AI-Hypercomputer:mattdavidow-remove-shmap-test
PiperOrigin-RevId: 754072676
2 parents 5759cdb + f202216 commit b5cf469

3 files changed

Lines changed: 25 additions & 14 deletions

File tree

MaxText/tests/integration_tests/shmap_collective_matmul_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
from MaxText.globals import PKG_DIR
2323

2424
sys.path.append(os.path.join(os.path.dirname(PKG_DIR), "pedagogical_examples"))
25-
from pedagogical_examples.shmap_collective_matmul import main
2625

26+
# Uncomment the import when b/415022795 is fixed
27+
#from pedagogical_examples.shmap_collective_matmul import main
2728

29+
30+
@pytest.mark.skip(reason="Enable when b/415022795 is fixed")
2831
@pytest.mark.integration_test
2932
@pytest.mark.tpu_only
3033
def test_shmap_collective_matmul_example():
3134
"""Validate Pedagogical Example, Shmap_collective_matmul."""
32-
33-
assert main() is True
35+
# Uncomment the main() assertion when b/415022795 is fixed
36+
#assert main() is True

MaxText/tests/kernels_test.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ class RaggedAttentionTest(unittest.TestCase):
3535
head_dim = 128
3636

3737
dtype = jnp.float32
38-
key = jax.random.key(0)
39-
k1, k2, k3 = jax.random.split(key, 3)
4038

4139
@pytest.mark.tpu_only
4240
def test_ragged_mqa(self):
43-
q = jax.random.normal(self.k1, (self.batch_size, 1, self.head_dim), dtype=self.dtype)
44-
k = jax.random.normal(self.k2, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype)
45-
v = jax.random.normal(self.k3, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype)
41+
key = jax.random.key(0)
42+
k1, k2, k3 = jax.random.split(key, 3)
43+
44+
q = jax.random.normal(k1, (self.batch_size, 1, self.head_dim), dtype=self.dtype)
45+
k = jax.random.normal(k2, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype)
46+
v = jax.random.normal(k3, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype)
4647
lengths = jnp.array(np.random.randint(1, self.max_target_length, self.batch_size), dtype=jnp.int32)
4748

4849
ragged_out, ragged_max, ragged_denom = ragged_mqa(q, k, v, lengths)
@@ -58,12 +59,15 @@ def test_ragged_mqa(self):
5859

5960
@pytest.mark.tpu_only
6061
def test_ragged_mha(self):
61-
q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype)
62+
key = jax.random.key(0)
63+
k1, k2, k3 = jax.random.split(key, 3)
64+
65+
q = jax.random.normal(k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype)
6266
k = jax.random.normal(
63-
self.k2, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype
67+
k2, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype
6468
)
6569
v = jax.random.normal(
66-
self.k3, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype
70+
k3, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype
6771
)
6872
lengths = jnp.array(np.random.randint(1, self.max_target_length, self.batch_size), dtype=jnp.int32)
6973

@@ -81,12 +85,15 @@ def test_ragged_mha(self):
8185

8286
@pytest.mark.tpu_only
8387
def test_ragged_gqa(self):
84-
q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype)
88+
key = jax.random.key(0)
89+
k1, k2, k3 = jax.random.split(key, 3)
90+
91+
q = jax.random.normal(k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype)
8592
k = jax.random.normal(
86-
self.k2, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype
93+
k2, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype
8794
)
8895
v = jax.random.normal(
89-
self.k3, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype
96+
k3, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype
9097
)
9198
lengths = jnp.array(np.random.randint(1, self.max_target_length, self.batch_size), dtype=jnp.int32)
9299

pedagogical_examples/shmap_collective_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MESH_FSDP_AXIS = "fsdp"
3636
MESH_TENSOR_AXIS = "tp"
3737

38+
# We should not call jax.devices() at import time (see b/415022795).
3839
d = jax.devices()
3940
outd = [[[d[0], d[1], d[3], d[2]]]]
4041
global_mesh = Mesh(outd, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))

0 commit comments

Comments
 (0)