Skip to content

Commit b153e8c

Browse files
xibinliuecnal-cienet
authored and committed
NNX migration: modify the print_shardings_params to support NNX
1 parent 4a574ea commit b153e8c

2 files changed

Lines changed: 37 additions & 18 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,26 +1563,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
15631563
"""
15641564
Print state shardings comparing Logical Definition vs Physical Result.
15651565
"""
1566-
if not hasattr(params, "params"):
1567-
params = {"params": params}
1568-
if not hasattr(params_sharding, "params"):
1569-
params_sharding = {"params": params_sharding}
1570-
if logical_annotations and not hasattr(logical_annotations, "params"):
1571-
logical_annotations = {"params": logical_annotations}
1566+
if not isinstance(params, nnx.State):
1567+
if not hasattr(params, "params"):
1568+
params = {"params": params}
1569+
if not hasattr(params_sharding, "params"):
1570+
params_sharding = {"params": params_sharding}
1571+
if logical_annotations and not hasattr(logical_annotations, "params"):
1572+
logical_annotations = {"params": logical_annotations}
15721573

15731574
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
15741575
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1575-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
15761576

1577-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1578-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1579-
shape = jax.typeof(leaf_val)
1580-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1581-
pspec_str = str(tuple(pspec))
1582-
logical_str = str(leaf_logical_val)
1583-
1584-
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1585-
max_logging.info(message)
1577+
if logical_annotations is not None:
1578+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1579+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
1580+
leaves_params, leaves_sharding, leaves_logical
1581+
):
1582+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1583+
shape = jax.typeof(leaf_val)
1584+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1585+
pspec_str = str(tuple(pspec))
1586+
logical_str = str(leaf_logical_val)
1587+
1588+
message = (
1589+
f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1590+
)
1591+
max_logging.info(message)
1592+
else:
1593+
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1594+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1595+
shape = jax.typeof(leaf_val)
1596+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1597+
pspec_str = str(tuple(pspec))
1598+
1599+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
1600+
max_logging.info(message)
15861601

15871602
print(flush=True)
15881603

tests/unit/maxtext_utils_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def setUp(self):
182182
},
183183
"decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}},
184184
}
185-
self.state = train_state.TrainState(step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={})
185+
self.state = train_state.TrainState(
186+
step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={}
187+
)
186188

187189
def test_update_mode_add(self):
188190
target_path = ("decoder", "gate", "bias")
@@ -723,7 +725,9 @@ def test_low_temperature_is_greedy(self):
723725
rngs = jax.random.split(self.rng, 10)
724726

725727
for r in rngs:
726-
token = inference_utils.sample_topk_topp_weighted(self.logits, topk=10, nucleus_topp=1.0, temperature=low_temp, rng=r)
728+
token = inference_utils.sample_topk_topp_weighted(
729+
self.logits, topk=10, nucleus_topp=1.0, temperature=low_temp, rng=r
730+
)
727731
self.assertEqual(token.item(), greedy_token_index)
728732

729733
def test_invalid_args_raise_error(self):

0 commit comments

Comments
 (0)