File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 3333from maxtext .layers .embeddings import PartialRotaryEmbedding , RotaryEmbedding
3434from maxtext .configs import pyconfig
3535from maxtext .utils import maxtext_utils
36- from tests .utils .test_helpers import get_test_config_path
36+ from tests .utils .test_helpers import get_test_config_path , get_decoupled_parallelism_overrides
3737
3838
3939class PartialRotaryEmbeddingTest (unittest .TestCase ):
@@ -42,10 +42,12 @@ class PartialRotaryEmbeddingTest(unittest.TestCase):
4242 def setUp (self ):
4343 super ().setUp ()
4444 # build a simple config and mesh like other embedding tests
45+ extra_args = get_decoupled_parallelism_overrides ()
4546 self .cfg = pyconfig .initialize (
4647 [sys .argv [0 ], get_test_config_path ()],
4748 run_name = "test_embeddings" ,
4849 enable_checkpointing = False ,
50+ ** extra_args ,
4951 )
5052 devices_array = maxtext_utils .create_device_mesh (self .cfg )
5153 self .mesh = Mesh (devices_array , self .cfg .mesh_axes )
You can’t perform that action at this time.
0 commit comments