Skip to content

Commit 0627ad6

Browse files
committed
add sharding debug feature
1 parent c2574ab commit 0627ad6

6 files changed

Lines changed: 37 additions & 3 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ enable_jax_profiler: False
832832
jax_profiler_port: 9999
833833

834834
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
835+
debug_sharding: False # Prints model weight sharding info
835836

836837
# Checkpoint Structured logging
837838
enable_checkpoint_cloud_logger: False

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class RunInfo(BaseModel):
243243
True,
244244
description="If True, prints the final configuration after initialization.",
245245
)
246+
debug_sharding: bool = Field(False, description="If True, print model weight sharding details.")
246247
base_output_directory: PathStr = Field("", description="Base directory for all outputs, typically a GCS path.")
247248
sharding_strategy: None | Literal["experimental"] = Field(
248249
None,

src/MaxText/max_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,13 @@ def get_batch_seq_len_for_mode(config, model_mode):
989989
return batch_size, seq_len
990990

991991

992+
def print_non_trivial_mesh_axis(mesh):
  """Print every mesh axis whose size is greater than one, one per line."""
  # mesh.shape maps axis name -> axis size; size-1 axes carry no sharding and are skipped.
  non_trivial = ((name, size) for name, size in mesh.shape.items() if size > 1)
  for name, size in non_trivial:
    print(f"{name}: {size}", flush=True)
997+
998+
992999
@contextmanager
9931000
def maybe_get_transformer_engine_context(config):
9941001
"""Runs a transformer engine context engine manager for GPUs only."""

src/MaxText/maxtext_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,3 +1142,14 @@ def schedule(step):
11421142
boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
11431143

11441144
return optax.join_schedules(pieces, boundaries)
1145+
1146+
1147+
def print_state_mesh_shardings_params(state, state_sharding, mesh):
  """Print one line per parameter leaf: dotted-padded path, abstract value, and sharding spec.

  Assumes state.params and state_sharding.params share the same pytree
  structure, so the flattened leaf sequences line up pairwise.
  """
  param_leaves, _ = jax.tree_util.tree_flatten_with_path(state.params)
  sharding_leaves, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
  for (key_path, value), (_, leaf_sharding) in zip(param_leaves, sharding_leaves):
    name = "/".join(str(entry.key) for entry in key_path)
    aval = jax.typeof(value)
    # Drop size-1 mesh axes from the spec so only meaningful sharding is shown.
    spec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
    print(f"{name:.<80} {aval} {spec}", flush=True)

src/MaxText/train_compile.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import jax
3131
from jax.experimental.topologies import get_topology_desc
32-
from jax.sharding import Mesh
32+
from jax.sharding import Mesh, AxisType
3333
from jax.experimental.serialize_executable import serialize
3434

3535
from flax.linen import partitioning as nn_partitioning
@@ -41,7 +41,7 @@
4141
from MaxText import max_utils
4242
from MaxText import pyconfig
4343
from MaxText import sharding
44-
from MaxText.common_types import MODEL_MODE_TRAIN
44+
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
4545
from MaxText.layers import models
4646
from MaxText.layers import quantizations
4747
from MaxText.utils import gcs_utils
@@ -77,8 +77,11 @@ def get_topology_mesh(config):
7777
num_slices=config.compile_topology_num_slices,
7878
wrap=target_hardware.wrap,
7979
).devices
80+
if config.shard_mode == ShardMode.EXPLICIT:
81+
jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
8082
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
81-
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes)
83+
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
84+
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
8285
return topology_mesh
8386

8487

@@ -221,6 +224,11 @@ def main(argv: Sequence[str]) -> None:
221224
)
222225
)
223226

227+
# Print weight sharding info when debug_sharding is enabled
228+
if config.debug_sharding:
229+
max_utils.print_non_trivial_mesh_axis(topology_mesh)
230+
maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh)
231+
224232
# Compile
225233
print("Jitting and compiling train step...", flush=True)
226234
compiled = jit_and_compile(

src/MaxText/train_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from MaxText import checkpointing
2121
from MaxText import max_logging
22+
from MaxText import max_utils
2223
from MaxText import maxtext_utils
2324
from MaxText import sharding
2425
from MaxText import optimizers
@@ -213,6 +214,11 @@ def setup_train_loop(config, recorder, devices=None):
213214
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
214215
sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
215216

217+
# Print weight sharding info when debug_sharding is enabled
218+
if config.debug_sharding:
219+
max_utils.print_non_trivial_mesh_axis(model.mesh)
220+
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
221+
216222
if config.use_dpo:
217223
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
218224
max_logging.log(

0 commit comments

Comments
 (0)