Skip to content

Commit 6591621

Browse files
author
Sharon Yu
committed
print out logic axes
1 parent d01e65d commit 6591621

7 files changed

Lines changed: 142 additions & 15 deletions

File tree

src/MaxText/max_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,3 +1032,14 @@ def transformer_engine_context():
10321032
yield
10331033
except (ImportError, AttributeError):
10341034
yield
1035+
1036+
1037+
def print_mesh_axes_info(mesh: jax.sharding.Mesh):
  """Log every mesh axis name and its size on one comma-separated line."""
  if not mesh.shape:
    max_logging.info("Mesh Axes: (Empty Mesh)")
    return

  # mesh.shape maps axis name -> axis size; render as "name: size" pairs.
  joined = ", ".join(f"{name}: {size}" for name, size in mesh.shape.items())
  max_logging.info("Mesh Axes: (" + joined + ")")

src/MaxText/maxtext_utils.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import functools
1919
import pickle
2020

21+
from collections import defaultdict
22+
2123
from flax import linen as nn
2224
from flax.linen import partitioning as nn_partitioning
2325
from flax.training import train_state
@@ -26,6 +28,7 @@
2628

2729
from jax.experimental import mesh_utils
2830
from jax.experimental.serialize_executable import deserialize_and_load
31+
from jax.sharding import PartitionSpec as P
2932

3033
import jax
3134
import jax.numpy as jnp
@@ -1204,12 +1207,99 @@ def schedule(step):
12041207
return optax.join_schedules(pieces, boundaries)
12051208

12061209

1207-
def _normalize_rule_value(value):
  """Normalize a logical-axis-rule value (str/None/list/tuple/other) to a hashable tuple key."""
  if isinstance(value, (list, tuple)):
    return tuple(value)
  return (value,)


def _semantic_names(rule_val_item, rmap):
  """Translate one physical rule value back to its logical axis name(s).

  Returns "{name1 | name2}" when one or more logical names map to the value,
  otherwise a quoted representation of the value itself.
  """
  if rule_val_item is None:
    key = (None,)
  elif isinstance(rule_val_item, str):
    key = (rule_val_item,)
  elif isinstance(rule_val_item, tuple):
    key = rule_val_item
  else:
    # Unhashable / unexpected value: show it verbatim rather than failing lookup.
    return f"'{str(rule_val_item)}'"

  names = rmap.get(key)
  if names:
    return "{" + " | ".join(sorted(set(names))) + "}"
  return f"'{str(key)}'"


def print_shardings_params(params, params_sharding, mesh, state_logical_annotations=None, logical_axis_rules=None):
  """Print each parameter's path, shape, logical axes, and physical PartitionSpec.

  Args:
    params: PyTree of parameter leaves.
    params_sharding: PyTree of shardings with the same structure as ``params``.
    mesh: Device mesh, used to drop trivial (size-1) axes from the physical spec.
    state_logical_annotations: Optional object with a ``.params`` PyTree of
      logical PartitionSpecs, one per parameter leaf. When absent, the
      "Logical Axes" column reads "None (No Logical Info)".
    logical_axis_rules: Optional dict or iterable of ``(logical_name, mesh_axes)``
      pairs used to translate each physical rule value back to logical name(s).
  """
  leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
  leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)

  if state_logical_annotations and hasattr(state_logical_annotations, "params"):
    leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(state_logical_annotations.params)
  else:
    # No annotations available: pad with None placeholders so the zip below lines up.
    leaves_rule_values = [(None, None)] * len(leaves_params)

  if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values):
    max_logging.warning("Warning: Parameter tree structure mismatch between params, shardings, and logical annotations.")
    return

  # Reverse map: normalized rule value (tuple of mesh axes) -> logical axis name(s).
  rule_value_to_semantic = defaultdict(list)
  if logical_axis_rules:
    rules_iter = logical_axis_rules.items() if isinstance(logical_axis_rules, dict) else logical_axis_rules
    for name, potentials in rules_iter:
      rule_value_to_semantic[_normalize_rule_value(potentials)].append(name)

  # Header for the entire block.
  max_logging.info("Parameter Path")
  max_logging.info("Shape")
  max_logging.info("Logical Axes")
  max_logging.info("Physical PartitionSpec")
  max_logging.info("-" * 120)

  for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip(leaves_params, leaves_sharding, leaves_rule_values):
    path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
    shape = jax.typeof(leaf_val)

    # Physical PartitionSpec, simplified by removing size-1 mesh axes.
    pspec_str = "N/A"
    if hasattr(leaf_sharding, "spec"):
      pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
      pspec_str = str(tuple(pspec))
    elif leaf_sharding is not None:
      pspec_str = str(leaf_sharding)

    # Logical Axes string representation.
    if leaf_rule_value is None:
      semantic_str = "None (No Logical Info)"
    elif isinstance(leaf_rule_value, P):
      semantic_parts = [str(_semantic_names(s, rule_value_to_semantic)) for s in leaf_rule_value]
      semantic_str = "Partitionspec(" + ", ".join(semantic_parts) + ")"
    else:
      semantic_str = str(leaf_rule_value)

    # Multi-line logging: one field per line, separated by a rule.
    max_logging.info(f"{path_str}")
    max_logging.info(f"{shape}")
    max_logging.info(f"{semantic_str}")
    max_logging.info(f"{pspec_str}")
    max_logging.info("-" * 120)

src/MaxText/model_creation_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,13 @@ def create_sharded_state():
157157
# print weights sharding info under debug sharding mode
158158
if config.debug_sharding:
159159
max_utils.print_non_trivial_mesh_axis(model.mesh)
160-
maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
160+
maxtext_utils.print_state_mesh_shardings_params(
161+
state=sharded_state,
162+
state_sharding=out_shardings,
163+
state_logical_annotations=specs,
164+
mesh=model.mesh,
165+
logical_axis_rules=config.logical_axis_rules,
166+
)
161167
if config.load_parameters_path:
162168
try:
163169
ckptr = ocp.Checkpointer(

src/MaxText/train_compile.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_shaped_inputs(topology_mesh, config):
100100
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
101101

102102
# Shaped state
103-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
103+
abstract_state, state_logical_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
104104
model, tx, config, example_rng, topology_mesh
105105
)
106106

@@ -109,7 +109,7 @@ def get_shaped_inputs(topology_mesh, config):
109109

110110
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
111111
shaped_train_kwargs = {}
112-
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model
112+
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, state_logical_annotations, model
113113

114114

115115
def jit_and_compile(
@@ -158,7 +158,13 @@ def is_oom(argv: Sequence[str]) -> bool:
158158
max_utils.print_system_information()
159159

160160
# Get shaped inputs
161-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
161+
(
162+
shaped_train_args,
163+
shaped_train_kwargs,
164+
state_mesh_shardings,
165+
_,
166+
model,
167+
) = get_shaped_inputs(topology_mesh, config)
162168

163169
# Get data sharding
164170
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -213,7 +219,13 @@ def main(argv: Sequence[str]) -> None:
213219
max_utils.print_system_information()
214220

215221
# Get shaped inputs
216-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
222+
(
223+
shaped_train_args,
224+
shaped_train_kwargs,
225+
state_mesh_shardings,
226+
state_logical_annotations,
227+
model,
228+
) = get_shaped_inputs(topology_mesh, config)
217229

218230
# Get data sharding
219231
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -227,8 +239,14 @@ def main(argv: Sequence[str]) -> None:
227239

228240
# print weights sharding info under debug sharding mode
229241
if config.debug_sharding:
230-
max_utils.print_non_trivial_mesh_axis(topology_mesh)
231-
maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)
242+
max_utils.print_mesh_axes_info(topology_mesh)
243+
maxtext_utils.print_shardings_params(
244+
shaped_train_args[0].params,
245+
state_mesh_shardings.params,
246+
topology_mesh,
247+
state_logical_annotations,
248+
config.logical_axis_rules,
249+
)
232250

233251
# Compile
234252
print("Jitting and compiling train step...", flush=True)

src/MaxText/train_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def setup_train_loop(config, recorder, devices=None):
207207
eval_data_iterator,
208208
)
209209

210-
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
210+
state, state_mesh_annotations, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
211211
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
212212
)
213213

@@ -219,7 +219,9 @@ def setup_train_loop(config, recorder, devices=None):
219219
# print weights sharding info under debug sharding mode
220220
if config.debug_sharding:
221221
max_utils.print_non_trivial_mesh_axis(model.mesh)
222-
maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)
222+
maxtext_utils.print_shardings_params(
223+
state.params, state_mesh_shardings.params, model.mesh, state_mesh_annotations, config.logical_axis_rules
224+
)
223225

224226
if config.use_dpo:
225227
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)

tests/unit/sharding_compare_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
9797
validate_config(config)
9898

9999
topology_mesh = get_topology_mesh(config)
100-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
100+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
101101
actual_json = named_shardings_to_json(state_mesh_shardings)
102102
expected_json = load_named_sharding_json(json_path)
103103

tests/utils/sharding_dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def main(argv: Sequence[str]) -> None:
276276

277277
try:
278278
topology_mesh = get_topology_mesh(config)
279-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
279+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
280280
except: # pylint: disable=bare-except
281281
state_mesh_shardings = {}
282282

0 commit comments

Comments
 (0)