Skip to content

Commit a8f1147

Browse files
author
Sharon Yu
committed
print out logical axes
1 parent 1137c42 commit a8f1147

6 files changed

Lines changed: 153 additions & 14 deletions

File tree

src/MaxText/max_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,3 +1032,17 @@ def transformer_engine_context():
10321032
yield
10331033
except (ImportError, AttributeError):
10341034
yield
1035+
1036+
1037+
def print_mesh_axes_info(mesh: jax.sharding.Mesh):
  """Logs every mesh axis and its size on one comma-separated line."""
  if not mesh.shape:
    max_logging.info("Mesh Axes: (Empty Mesh)")
    return

  # mesh.shape maps axis name -> axis size; render as "name: size" pairs.
  axes_summary = ", ".join(f"{name}: {size}" for name, size in mesh.shape.items())
  max_logging.info(f"Mesh Axes: ({axes_summary})")

src/MaxText/maxtext_utils.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import functools
1919
import pickle
2020

21+
from collections import defaultdict
22+
2123
from flax import linen as nn
2224
from flax.linen import partitioning as nn_partitioning
2325
from flax.training import train_state
@@ -26,6 +28,7 @@
2628

2729
from jax.experimental import mesh_utils
2830
from jax.experimental.serialize_executable import deserialize_and_load
31+
from jax.sharding import PartitionSpec as P
2932

3033
import jax
3134
import jax.numpy as jnp
@@ -1204,12 +1207,109 @@ def schedule(step):
12041207
return optax.join_schedules(pieces, boundaries)
12051208

12061209

1207-
def print_state_mesh_shardings_params(
    state, state_sharding, state_logical_annotations, mesh, logical_axis_rules
):
  """Print per-parameter sharding info: path, shape, logical axes, physical spec.

  For every leaf in ``state.params``, logs the parameter path, its abstract
  shape, the logical axis names recovered from ``logical_axis_rules``, and the
  physical ``PartitionSpec`` (with size-1 mesh axes removed).

  Args:
    state: Train state whose ``params`` pytree holds the parameters.
    state_sharding: Pytree (matching ``state.params``) of ``NamedSharding``s.
    state_logical_annotations: Pytree (matching ``state.params``) of logical
      ``PartitionSpec`` annotations.
    mesh: Device mesh used to simplify the physical specs.
    logical_axis_rules: Mapping (dict or iterable of ``(name, axes)`` pairs)
      from a logical axis name to its candidate physical mesh axes.
  """
  if (
      not hasattr(state, "params")
      or not hasattr(state_sharding, "params")
      or not hasattr(state_logical_annotations, "params")
  ):
    max_logging.warning(
        "Warning: 'params' attribute missing in one of the inputs to "
        "print_state_mesh_shardings_params."
    )
    return

  leaves_params, _ = jax.tree_util.tree_flatten_with_path(state.params)
  leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
  leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(
      state_logical_annotations.params
  )

  if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values):
    max_logging.warning(
        "Warning: Parameter tree structure mismatch between state, shardings,"
        " and logical annotations."
    )
    return

  # Build a reverse map (potential physical axes tuple -> list of logical names)
  rule_value_to_semantic = defaultdict(list)
  if logical_axis_rules:
    rules_iter = (
        logical_axis_rules.items()
        if isinstance(logical_axis_rules, dict)
        else logical_axis_rules
    )
    for name, potentials in rules_iter:
      # name: LHS, for example 'embed' or 'activation_batch'.
      # potentials: RHS, for example 'data', 'model', None, ['data', 'model'].
      if potentials is None:
        key = (None,)
      elif isinstance(potentials, (list, tuple)):
        key = tuple(potentials)
      else:
        # Covers strings and any other scalar rule value.
        key = (potentials,)
      rule_value_to_semantic[key].append(name)

  # Defined once, outside the per-parameter loop (it is loop-invariant).
  def get_semantic_names(rule_val_item):
    """Maps one PartitionSpec entry back to the logical name(s) that produce it."""
    if rule_val_item is None:
      key = (None,)
    elif isinstance(rule_val_item, str):
      key = (rule_val_item,)
    elif isinstance(rule_val_item, tuple):
      key = rule_val_item
    else:
      return f"'{str(rule_val_item)}'"

    names = rule_value_to_semantic.get(key)
    if names:
      return "{" + " | ".join(sorted(set(names))) + "}"
    # Show the raw rule value if it is not covered by any rule.
    return f"'{str(key)}'"

  # Header for the entire block.
  max_logging.info("Parameter Path")
  max_logging.info("Shape")
  max_logging.info("Logical Axes")
  max_logging.info("Physical PartitionSpec")
  max_logging.info("-" * 120)

  for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip(
      leaves_params, leaves_sharding, leaves_rule_values
  ):
    path_str = "/".join(str(p.key) for p in path)
    shape = str(jax.typeof(leaf_val))

    # Physical PartitionSpec from NamedSharding, with trivial axes dropped.
    pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
    pspec_str = str(tuple(pspec))

    # Logical axes string representation.
    if isinstance(leaf_rule_value, P):
      semantic_parts = [str(get_semantic_names(s)) for s in leaf_rule_value]
      # Fixed label casing: "PartitionSpec", matching the JAX class name.
      semantic_str = "PartitionSpec(" + ", ".join(semantic_parts) + ")"
    elif leaf_rule_value is None:
      semantic_str = "PartitionSpec(None)"
    else:  # Should not be common.
      semantic_str = str(leaf_rule_value)

    # Multi-line logging for each parameter.
    max_logging.info(f"{path_str}")
    max_logging.info(f"{shape}")
    max_logging.info(f"{semantic_str}")
    max_logging.info(f"{pspec_str}")
    max_logging.info("-" * 120)
  print(flush=True)

src/MaxText/train_compile.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_shaped_inputs(topology_mesh, config):
100100
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
101101

102102
# Shaped state
103-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
103+
abstract_state, state_logical_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
104104
model, tx, config, example_rng, topology_mesh
105105
)
106106

@@ -109,7 +109,7 @@ def get_shaped_inputs(topology_mesh, config):
109109

110110
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
111111
shaped_train_kwargs = {}
112-
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model
112+
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, state_logical_annotations, model
113113

114114

115115
def jit_and_compile(
@@ -158,7 +158,13 @@ def is_oom(argv: Sequence[str]) -> bool:
158158
max_utils.print_system_information()
159159

160160
# Get shaped inputs
161-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
161+
(
162+
shaped_train_args,
163+
shaped_train_kwargs,
164+
state_mesh_shardings,
165+
_,
166+
model,
167+
) = get_shaped_inputs(topology_mesh, config)
162168

163169
# Get data sharding
164170
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -213,7 +219,13 @@ def main(argv: Sequence[str]) -> None:
213219
max_utils.print_system_information()
214220

215221
# Get shaped inputs
216-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
222+
(
223+
shaped_train_args,
224+
shaped_train_kwargs,
225+
state_mesh_shardings,
226+
state_logical_annotations,
227+
model,
228+
) = get_shaped_inputs(topology_mesh, config)
217229

218230
# Get data sharding
219231
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -227,8 +239,15 @@ def main(argv: Sequence[str]) -> None:
227239

228240
# print weights sharding info under debug sharding mode
229241
if config.debug_sharding:
230-
max_utils.print_non_trivial_mesh_axis(topology_mesh)
231-
maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh)
242+
# max_utils.print_non_trivial_mesh_axis(topology_mesh)
243+
max_utils.print_mesh_axes_info(topology_mesh)
244+
maxtext_utils.print_state_mesh_shardings_params(
245+
shaped_train_args[0],
246+
state_mesh_shardings,
247+
state_logical_annotations,
248+
topology_mesh,
249+
config.logical_axis_rules,
250+
)
232251

233252
# Compile
234253
print("Jitting and compiling train step...", flush=True)

src/MaxText/train_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def setup_train_loop(config, recorder, devices=None):
206206
eval_data_iterator,
207207
)
208208

209-
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
209+
state, state_mesh_annotations, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
210210
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
211211
)
212212

@@ -218,7 +218,13 @@ def setup_train_loop(config, recorder, devices=None):
218218
# print weights sharding info under debug sharding mode
219219
if config.debug_sharding:
220220
max_utils.print_non_trivial_mesh_axis(model.mesh)
221-
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
221+
maxtext_utils.print_state_mesh_shardings_params(
222+
state,
223+
state_mesh_shardings,
224+
state_mesh_annotations,
225+
model.mesh,
226+
config.logical_axis_rules
227+
)
222228

223229
if config.use_dpo:
224230
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)

tests/sharding_compare_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
9797
validate_config(config)
9898

9999
topology_mesh = get_topology_mesh(config)
100-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
100+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
101101
actual_json = named_shardings_to_json(state_mesh_shardings)
102102
expected_json = load_named_sharding_json(json_path)
103103

tests/sharding_dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def main(argv: Sequence[str]) -> None:
275275

276276
try:
277277
topology_mesh = get_topology_mesh(config)
278-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
278+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
279279
except: # pylint: disable=bare-except
280280
state_mesh_shardings = {}
281281

0 commit comments

Comments
 (0)