Skip to content

Commit 4452182

Browse files
xibinliu, ecnal-cienet
authored and committed
NNX migration: modify the print_shardings_params to support NNX
1 parent 5dd880f commit 4452182

2 files changed

Lines changed: 37 additions & 18 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,26 +1653,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
16531653
"""
16541654
Print state shardings comparing Logical Definition vs Physical Result.
16551655
"""
1656-
if not hasattr(params, "params"):
1657-
params = {"params": params}
1658-
if not hasattr(params_sharding, "params"):
1659-
params_sharding = {"params": params_sharding}
1660-
if logical_annotations and not hasattr(logical_annotations, "params"):
1661-
logical_annotations = {"params": logical_annotations}
1656+
if not isinstance(params, nnx.State):
1657+
if not hasattr(params, "params"):
1658+
params = {"params": params}
1659+
if not hasattr(params_sharding, "params"):
1660+
params_sharding = {"params": params_sharding}
1661+
if logical_annotations and not hasattr(logical_annotations, "params"):
1662+
logical_annotations = {"params": logical_annotations}
16621663

16631664
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
16641665
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1665-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
16661666

1667-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1668-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1669-
shape = jax.typeof(leaf_val)
1670-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1671-
pspec_str = str(tuple(pspec))
1672-
logical_str = str(leaf_logical_val)
1673-
1674-
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1675-
max_logging.info(message)
1667+
if logical_annotations is not None:
1668+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1669+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
1670+
leaves_params, leaves_sharding, leaves_logical
1671+
):
1672+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1673+
shape = jax.typeof(leaf_val)
1674+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1675+
pspec_str = str(tuple(pspec))
1676+
logical_str = str(leaf_logical_val)
1677+
1678+
message = (
1679+
f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1680+
)
1681+
max_logging.info(message)
1682+
else:
1683+
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1684+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1685+
shape = jax.typeof(leaf_val)
1686+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1687+
pspec_str = str(tuple(pspec))
1688+
1689+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
1690+
max_logging.info(message)
16761691

16771692
print(flush=True)
16781693

tests/unit/maxtext_utils_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def setUp(self):
180180
},
181181
"decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}},
182182
}
183-
self.state = train_state.TrainState(step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={})
183+
self.state = train_state.TrainState(
184+
step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={}
185+
)
184186

185187
def test_update_mode_add(self):
186188
target_path = ("decoder", "gate", "bias")
@@ -721,7 +723,9 @@ def test_low_temperature_is_greedy(self):
721723
rngs = jax.random.split(self.rng, 10)
722724

723725
for r in rngs:
724-
token = inference_utils.sample_topk_topp_weighted(self.logits, topk=10, nucleus_topp=1.0, temperature=low_temp, rng=r)
726+
token = inference_utils.sample_topk_topp_weighted(
727+
self.logits, topk=10, nucleus_topp=1.0, temperature=low_temp, rng=r
728+
)
725729
self.assertEqual(token.item(), greedy_token_index)
726730

727731
def test_invalid_args_raise_error(self):

0 commit comments

Comments
 (0)