Skip to content

Commit 23097b9

Browse files
xibinliuecnal-cienet
authored andcommitted
NNX migration: modify the print_shardings_params to support NNX
1 parent e8fb6e6 commit 23097b9

2 files changed

Lines changed: 874 additions & 20 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,26 +1563,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
15631563
"""
15641564
Print state shardings comparing Logical Definition vs Physical Result.
15651565
"""
1566-
if not hasattr(params, "params"):
1567-
params = {"params": params}
1568-
if not hasattr(params_sharding, "params"):
1569-
params_sharding = {"params": params_sharding}
1570-
if logical_annotations and not hasattr(logical_annotations, "params"):
1571-
logical_annotations = {"params": logical_annotations}
1566+
if not isinstance(params, nnx.State):
1567+
if not hasattr(params, "params"):
1568+
params = {"params": params}
1569+
if not hasattr(params_sharding, "params"):
1570+
params_sharding = {"params": params_sharding}
1571+
if logical_annotations and not hasattr(logical_annotations, "params"):
1572+
logical_annotations = {"params": logical_annotations}
15721573

15731574
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
15741575
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1575-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
15761576

1577-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
1578-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1579-
shape = jax.typeof(leaf_val)
1580-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1581-
pspec_str = str(tuple(pspec))
1582-
logical_str = str(leaf_logical_val)
1583-
1584-
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1585-
max_logging.info(message)
1577+
if logical_annotations is not None:
1578+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1579+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(
1580+
leaves_params, leaves_sharding, leaves_logical
1581+
):
1582+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1583+
shape = jax.typeof(leaf_val)
1584+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1585+
pspec_str = str(tuple(pspec))
1586+
logical_str = str(leaf_logical_val)
1587+
1588+
message = (
1589+
f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1590+
)
1591+
max_logging.info(message)
1592+
else:
1593+
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1594+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1595+
shape = jax.typeof(leaf_val)
1596+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1597+
pspec_str = str(tuple(pspec))
1598+
1599+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}"
1600+
max_logging.info(message)
15861601

15871602
print(flush=True)
15881603

0 commit comments

Comments
 (0)