
Commit 6433043

Author: Sharon Yu
Message: fix format issue
Parent: a8f1147

4 files changed: 25 additions & 40 deletions

src/MaxText/max_utils.py

Lines changed: 1 addition & 4 deletions
@@ -1040,9 +1040,6 @@ def print_mesh_axes_info(mesh: jax.sharding.Mesh):
     max_logging.info("Mesh Axes: (Empty Mesh)")
     return
 
-  axis_info = [
-      f"{axis_name}: {axis_size}"
-      for axis_name, axis_size in mesh.shape.items()
-  ]
+  axis_info = [f"{axis_name}: {axis_size}" for axis_name, axis_size in mesh.shape.items()]
   info_str = "Mesh Axes: (" + ", ".join(axis_info) + ")"
   max_logging.info(info_str)
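
Note: this change is purely cosmetic. jax.sharding.Mesh.shape is an ordered mapping of axis name to size, so the one-line comprehension produces the same strings as the old multi-line form. A minimal sketch (the mesh construction below is illustrative, not from the commit, and assumes at least one available JAX device):

import jax
import numpy as np
from jax.sharding import Mesh

# Build a toy 2-D mesh over whatever devices are available.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Same comprehension as the reformatted line in print_mesh_axes_info.
axis_info = [f"{axis_name}: {axis_size}" for axis_name, axis_size in mesh.shape.items()]
print("Mesh Axes: (" + ", ".join(axis_info) + ")")  # e.g. Mesh Axes: (data: 8, model: 1)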

src/MaxText/maxtext_utils.py

Lines changed: 13 additions & 21 deletions
@@ -1207,40 +1207,32 @@ def schedule(step):
   return optax.join_schedules(pieces, boundaries)
 
 
-def print_state_mesh_shardings_params(
-    state, state_sharding, state_logical_annotations, mesh, logical_axis_rules
-):
+def print_state_mesh_shardings_params(state, state_sharding, state_logical_annotations, mesh, logical_axis_rules):
   """Print state shardings."""
-  if (not hasattr(state, 'params') or
-      not hasattr(state_sharding, 'params') or
-      not hasattr(state_logical_annotations, 'params')):
+  if (
+      not hasattr(state, "params")
+      or not hasattr(state_sharding, "params")
+      or not hasattr(state_logical_annotations, "params")
+  ):
     max_logging.warning(
-        "Warning: 'params' attribute missing in one of the inputs to "
-        "print_state_mesh_shardings_params."
+        "Warning: 'params' attribute missing in one of the inputs to " "print_state_mesh_shardings_params."
     )
     return
 
   leaves_params, _ = jax.tree_util.tree_flatten_with_path(state.params)
   leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
-  leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(
-      state_logical_annotations.params
-  )
+  leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(state_logical_annotations.params)
 
   if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values):
     max_logging.warning(
-        "Warning: Parameter tree structure mismatch between state, shardings,"
-        " and logical annotations."
+        "Warning: Parameter tree structure mismatch between state, shardings," " and logical annotations."
     )
     return
 
   # Build a reverse map (Potential Physical Axes Tuple -> List of Semantic Names)
   rule_value_to_semantic = defaultdict(list)
   if logical_axis_rules:
-    rules_iter = (
-        logical_axis_rules.items()
-        if isinstance(logical_axis_rules, dict)
-        else logical_axis_rules
-    )
+    rules_iter = logical_axis_rules.items() if isinstance(logical_axis_rules, dict) else logical_axis_rules
     for name, potentials in rules_iter:
       # name: LHS for example 'embed/activation_batch
       # potentials: RHS for example 'data', 'model', None, ['data', 'model']
@@ -1289,7 +1281,7 @@ def get_semantic_names(rule_val_item, rmap):
     names = rmap.get(key)
 
     if names:
-        return "{" + " | ".join(sorted(list(set(names)))) + "}"
+      return "{" + " | ".join(sorted(list(set(names)))) + "}"
     else:
       # Show rule value if unmapped.
       return f"'{str(key)}'"
@@ -1302,8 +1294,8 @@ def get_semantic_names(rule_val_item, rmap):
       semantic_parts.append(str(name_str))
       semantic_str = "Partitionspec(" + ", ".join(semantic_parts) + ")"
     elif leaf_rule_value is None:
-        semantic_str = "Partitionspec(None)"
-    else: # Should not be common
+      semantic_str = "Partitionspec(None)"
+    else:  # Should not be common
       semantic_str = str(leaf_rule_value)
 
     # Multi-line logging for each parameter
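
Note: the joined warning messages rely on Python's implicit concatenation of adjacent string literals, so the logged text is byte-for-byte unchanged; the formatter merely moved both literals onto one line. A quick self-contained check:

# Adjacent string literals are concatenated at parse time,
# whether they sit on one line or across a parenthesized break.
old = ("Warning: 'params' attribute missing in one of the inputs to "
       "print_state_mesh_shardings_params.")
new = "Warning: 'params' attribute missing in one of the inputs to " "print_state_mesh_shardings_params."
assert old == new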

src/MaxText/train_compile.py

Lines changed: 10 additions & 10 deletions
@@ -159,11 +159,11 @@ def is_oom(argv: Sequence[str]) -> bool:
 
   # Get shaped inputs
   (
-    shaped_train_args,
-    shaped_train_kwargs,
-    state_mesh_shardings,
-    _,
-    model,
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      _,
+      model,
   ) = get_shaped_inputs(topology_mesh, config)
 
   # Get data sharding
@@ -220,11 +220,11 @@ def main(argv: Sequence[str]) -> None:
 
   # Get shaped inputs
   (
-    shaped_train_args,
-    shaped_train_kwargs,
-    state_mesh_shardings,
-    state_logical_annotations,
-    model,
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      state_logical_annotations,
+      model,
   ) = get_shaped_inputs(topology_mesh, config)
 
   # Get data sharding

src/MaxText/train_utils.py

Lines changed: 1 addition & 5 deletions
@@ -219,11 +219,7 @@ def setup_train_loop(config, recorder, devices=None):
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(model.mesh)
     maxtext_utils.print_state_mesh_shardings_params(
-        state,
-        state_mesh_shardings,
-        state_mesh_annotations,
-        model.mesh,
-        config.logical_axis_rules
+        state, state_mesh_shardings, state_mesh_annotations, model.mesh, config.logical_axis_rules
     )
 
   if config.use_dpo:
