Skip to content

Commit 1011fc0

Browse files
author
Sharon Yu
committed
fix comments
1 parent baeb5b6 commit 1011fc0

3 files changed

Lines changed: 30 additions & 28 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,21 +1246,18 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
12461246
params = {"params": params}
12471247
if not hasattr(params_sharding, "params"):
12481248
params_sharding = {"params": params_sharding}
1249+
if logical_annotations and not hasattr(logical_annotations, "params"):
1250+
logical_annotations = {"params": logical_annotations}
12491251

12501252
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
12511253
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1254+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params)
12521255

1253-
leaves_logical = []
1254-
has_logical = False
1255-
if logical_annotations and hasattr(logical_annotations, "params"):
1256-
try:
1257-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params)
1258-
if len(leaves_params) == len(leaves_logical):
1259-
has_logical = True
1260-
else:
1261-
max_logging.warning("Warning: Logical annotations tree structure mismatch. Skipping logical info.")
1262-
except Exception as e: # pylint: disable=broad-exception-caught
1263-
max_logging.warning(f"Warning: Failed to process logical annotations: {e}. Skipping logical info.")
1256+
for i, ((path, leaf_val), (_, leaf_sharding)) in enumerate(zip(leaves_params, leaves_sharding)):
1257+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1258+
shape = jax.typeof(leaf_val)
1259+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1260+
pspec_str = str(tuple(pspec))
12641261

12651262
if not has_logical:
12661263
leaves_logical = [(None, None)] * len(leaves_params)

src/MaxText/model_creation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def create_sharded_state():
160160
maxtext_utils.print_shardings_params(
161161
params=sharded_state,
162162
params_sharding=out_shardings,
163-
logical_annotations=specs,
164163
mesh=model.mesh,
164+
logical_annotations=specs,
165165
)
166166
if config.load_parameters_path:
167167
try:

src/MaxText/sharding.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,28 @@
3131

3232

3333
_LOGGED_ACTIVATION_SHARDINGS = set()
34+
_LOGGED_LOGICAL_AXES = set()
3435

3536

3637
def get_input_data_sharding(config, mesh):
3738
"""Get the input data sharding for the model"""
3839
return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
3940

4041

41-
def maybe_shard_with_name(
42-
inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, logical_axes=None
43-
):
42+
def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
4443
"""
4544
In auto shardmode, this function hints inputs follow given named_sharding.
4645
In explicit shardmode, this function enforces inputs following named_sharding.
4746
"""
4847
if inputs is None:
4948
return None
50-
if debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding):
49+
if (
50+
debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding)
51+
): # only print pspec for JitTracer
5152
pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
52-
if logical_axes is not None:
53-
logical_str = str(logical_axes)
54-
else:
55-
logical_str = "None"
56-
shape_str = str(jax.typeof(inputs))
57-
log_key = (shape_str, tuple(pspec), extra_stack_level, logical_str)
58-
53+
log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
5954
if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
60-
max_logging.info(
61-
f"Activation: {logical_str:<40} -> {str(tuple(pspec)):<30} {shape_str}", stacklevel=3 + extra_stack_level
62-
)
55+
max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
6356
_LOGGED_ACTIVATION_SHARDINGS.add(log_key)
6457
if shard_mode == ShardMode.EXPLICIT:
6558
return reshard(inputs, named_sharding)
@@ -75,14 +68,26 @@ def maybe_shard_with_logical(
7568
"""
7669
if inputs is None:
7770
return None
71+
7872
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
73+
74+
if debug_sharding and isinstance(inputs, Tracer):
75+
log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
76+
77+
if log_key not in _LOGGED_LOGICAL_AXES:
78+
pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
79+
pspec_str = str(tuple(pspec)) if pspec else "None"
80+
81+
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
82+
max_logging.info(f"{log_key[0]:.<80} {pspec_str}.", stacklevel=3 + extra_stack_level)
83+
_LOGGED_LOGICAL_AXES.add(log_key)
84+
7985
return maybe_shard_with_name(
8086
inputs,
8187
named_sharding,
8288
shard_mode,
83-
debug_sharding=debug_sharding,
89+
debug_sharding=False,
8490
extra_stack_level=extra_stack_level + 1,
85-
logical_axes=logical_axes,
8691
)
8792

8893

0 commit comments

Comments (0)