Skip to content

Commit f1e2f02

Browse files
Merge pull request #3341 from ROCm:ici-parallelism-PartialRotaryEmbeddingTest
PiperOrigin-RevId: 881501786
2 parents 81bbd64 + 0d509e2 commit f1e2f02

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

tests/unit/partial_rotary_embedding_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from maxtext.layers.embeddings import PartialRotaryEmbedding, RotaryEmbedding
3434
from maxtext.configs import pyconfig
3535
from maxtext.utils import maxtext_utils
36-
from tests.utils.test_helpers import get_test_config_path
36+
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
3737

3838

3939
class PartialRotaryEmbeddingTest(unittest.TestCase):
@@ -42,10 +42,12 @@ class PartialRotaryEmbeddingTest(unittest.TestCase):
4242
def setUp(self):
4343
super().setUp()
4444
# build a simple config and mesh like other embedding tests
45+
extra_args = get_decoupled_parallelism_overrides()
4546
self.cfg = pyconfig.initialize(
4647
[sys.argv[0], get_test_config_path()],
4748
run_name="test_embeddings",
4849
enable_checkpointing=False,
50+
**extra_args,
4951
)
5052
devices_array = maxtext_utils.create_device_mesh(self.cfg)
5153
self.mesh = Mesh(devices_array, self.cfg.mesh_axes)

0 commit comments

Comments
 (0)