Skip to content

Commit 00a5915

Browse files
committed
Fix: Fixed UT jax.make_mesh should take jax.sharding.AxisType.Auto in tests
1 parent a17a295 commit 00a5915

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

axlearn/common/kv_cache/kv_cache_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
import pytest
8+
from jax.sharding import AxisType
89
from absl.testing import absltest, parameterized
910
from jax.sharding import PartitionSpec
1011

@@ -173,7 +174,7 @@ def test_init_states_kv_partition_spec(self):
173174
)
174175
shape = KVCache.Shape(batch_size=batch, kv_len=kv_len, num_kv_heads=heads, per_head_dim=dim)
175176

176-
with jax.make_mesh((4, 2), ("data", "model")):
177+
with jax.make_mesh((4, 2), ("data", "model"), axis_types=(AxisType.Auto, AxisType.Auto)):
177178

178179
@jax.jit
179180
def f():

axlearn/common/kv_cache/sliding_window_kv_cache_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import jax
66
import jax.numpy as jnp
7+
from jax.sharding import AxisType
78
import pytest
89
from absl.testing import absltest, parameterized
910
from jax.sharding import PartitionSpec
@@ -131,7 +132,7 @@ def test_init_states_kv_partition_spec(self):
131132
batch_size=batch, kv_len=32, num_kv_heads=heads, per_head_dim=dim
132133
)
133134

134-
with jax.make_mesh((4, 2), ("data", "model")):
135+
with jax.make_mesh((4, 2), ("data", "model"), axis_types=(AxisType.Auto, AxisType.Auto)):
135136

136137
@jax.jit
137138
def f():

0 commit comments

Comments
 (0)