|
25 | 25 | from jax import random |
26 | 26 |
|
27 | 27 | from flax import linen as nn |
| 28 | +from flax import nnx |
28 | 29 |
|
29 | 30 | import optax |
30 | 31 |
|
@@ -167,8 +168,16 @@ def test_unscan_train_state_params(self): |
167 | 168 | num_layers = config.base_num_decoder_layers |
168 | 169 |
|
169 | 170 | # Make a copy to unscan, leaving the original state intact. |
170 | | - params_to_unscan = jax.tree_util.tree_map(lambda x: x, state.params) |
171 | | - sharding_to_unscan = jax.tree_util.tree_map(lambda x: x, sharding.params) |
| 171 | + if hasattr(state, "model"): |
| 172 | + _, params_state, _ = nnx.split(state.model, nnx.Param, ...) |
| 173 | + params_to_unscan = {"params": params_state.to_pure_dict()} |
| 174 | + else: |
| 175 | + params_to_unscan = jax.tree_util.tree_map(lambda x: x, state.params) |
| 176 | + if hasattr(sharding, "model"): |
| 177 | + _, sharding_params, _ = nnx.split(sharding.model, nnx.Param, ...) |
| 178 | + sharding_to_unscan = {"params": sharding_params.to_pure_dict()} |
| 179 | + else: |
| 180 | + sharding_to_unscan = jax.tree_util.tree_map(lambda x: x, sharding.params) |
172 | 181 |
|
173 | 182 | # Time the unscan operation. |
174 | 183 | start_time = time.time() |
@@ -196,8 +205,16 @@ def test_unscan_train_state_params(self): |
196 | 205 | self.assertEqual(unstacked_shape, expected_shape) |
197 | 206 |
|
198 | 207 | # Check that the original state is unchanged. |
199 | | - self.assertIn("layers", state.params["params"]["decoder"]) |
200 | | - self.assertNotIn("layers_0", state.params["params"]["decoder"]) |
| 208 | + if hasattr(state, "model"): |
| 209 | + _, params_state, _ = nnx.split(state.model, nnx.Param, ...) |
| 210 | + state_decoder_params = params_state.to_pure_dict()["decoder"] |
| 211 | + self.assertIn("layers", state_decoder_params) |
| 212 | + else: |
| 213 | + self.assertIn("layers", state.params["params"]["decoder"]) |
| 214 | + if hasattr(state, "model"): |
| 215 | + self.assertNotIn("layers_0", state_decoder_params) |
| 216 | + else: |
| 217 | + self.assertNotIn("layers_0", state.params["params"]["decoder"]) |
201 | 218 |
|
202 | 219 |
|
203 | 220 | class TestGpuDistributedInitialization(unittest.TestCase): |
|
0 commit comments