Skip to content

Commit 2c8efe0

Browse files
hsuan-lun-chiang authored and ecnal-cienet committed
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
1 parent 044c33c commit 2c8efe0

1 file changed

Lines changed: 21 additions & 4 deletions

File tree

tests/unit/max_utils_test.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from jax import random
2626

2727
from flax import linen as nn
28+
from flax import nnx
2829

2930
import optax
3031

@@ -167,8 +168,16 @@ def test_unscan_train_state_params(self):
167168
num_layers = config.base_num_decoder_layers
168169

169170
# Make a copy to unscan, leaving the original state intact.
170-
params_to_unscan = jax.tree_util.tree_map(lambda x: x, state.params)
171-
sharding_to_unscan = jax.tree_util.tree_map(lambda x: x, sharding.params)
171+
if hasattr(state, "model"):
172+
_, params_state, _ = nnx.split(state.model, nnx.Param, ...)
173+
params_to_unscan = {"params": params_state.to_pure_dict()}
174+
else:
175+
params_to_unscan = jax.tree_util.tree_map(lambda x: x, state.params)
176+
if hasattr(sharding, "model"):
177+
_, sharding_params, _ = nnx.split(sharding.model, nnx.Param, ...)
178+
sharding_to_unscan = {"params": sharding_params.to_pure_dict()}
179+
else:
180+
sharding_to_unscan = jax.tree_util.tree_map(lambda x: x, sharding.params)
172181

173182
# Time the unscan operation.
174183
start_time = time.time()
@@ -196,8 +205,16 @@ def test_unscan_train_state_params(self):
196205
self.assertEqual(unstacked_shape, expected_shape)
197206

198207
# Check that the original state is unchanged.
199-
self.assertIn("layers", state.params["params"]["decoder"])
200-
self.assertNotIn("layers_0", state.params["params"]["decoder"])
208+
if hasattr(state, "model"):
209+
_, params_state, _ = nnx.split(state.model, nnx.Param, ...)
210+
state_decoder_params = params_state.to_pure_dict()["decoder"]
211+
self.assertIn("layers", state_decoder_params)
212+
else:
213+
self.assertIn("layers", state.params["params"]["decoder"])
214+
if hasattr(state, "model"):
215+
self.assertNotIn("layers_0", state_decoder_params)
216+
else:
217+
self.assertNotIn("layers_0", state.params["params"]["decoder"])
201218

202219

203220
class TestGpuDistributedInitialization(unittest.TestCase):

0 commit comments

Comments
 (0)