Skip to content

Commit d0644bc

Browse files
author
Sharon Yu
committed
resolve conflict
1 parent 1011fc0 commit d0644bc

1 file changed

Lines changed: 5 additions & 22 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,34 +1253,17 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
12531253
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
12541254
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params)
12551255

1256-
for i, ((path, leaf_val), (_, leaf_sharding)) in enumerate(zip(leaves_params, leaves_sharding)):
1256+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
12571257
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
12581258
shape = jax.typeof(leaf_val)
12591259
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
12601260
pspec_str = str(tuple(pspec))
1261+
logical_str = str(leaf_logical_val)
12611262

1262-
if not has_logical:
1263-
leaves_logical = [(None, None)] * len(leaves_params)
1263+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1264+
max_logging.info(message)
12641265

1265-
if len(leaves_params) != len(leaves_sharding):
1266-
max_logging.warning("Warning: Params and Sharding tree mismatch.")
1267-
return
1268-
1269-
for i, (path, leaf_val) in enumerate(leaves_params):
1270-
_, leaf_sharding = leaves_sharding[i]
1271-
leaf_logical_val = leaves_logical[i][1] if has_logical else None
1272-
1273-
path_str = "/".join(str(p.key if hasattr(p, "key") else getattr(p, "name", "?")) for p in path)
1274-
1275-
shape = str(jax.typeof(leaf_val))
1276-
1277-
for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip(
1278-
leaves_params, leaves_sharding, leaves_rule_values
1279-
):
1280-
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1281-
shape = jax.typeof(leaf_val)
1282-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1283-
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
1266+
print(flush=True)
12841267

12851268

12861269
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):

0 commit comments

Comments
 (0)