Skip to content

Commit e5229d4

Browse files
committed
add sharding debug feature
1 parent c2574ab commit e5229d4

6 files changed

Lines changed: 36 additions & 4 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ enable_jax_profiler: False
832832
jax_profiler_port: 9999
833833

834834
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
835+
debug_sharding: False # Prints model weight sharding info
835836

836837
# Checkpoint Structured logging
837838
enable_checkpoint_cloud_logger: False

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class RunInfo(BaseModel):
243243
True,
244244
description="If True, prints the final configuration after initialization.",
245245
)
246+
debug_sharding: bool = Field(False, description="If True, print model weight sharding details.")
246247
base_output_directory: PathStr = Field("", description="Base directory for all outputs, typically a GCS path.")
247248
sharding_strategy: None | Literal["experimental"] = Field(
248249
None,

src/MaxText/max_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,13 @@ def get_batch_seq_len_for_mode(config, model_mode):
989989
return batch_size, seq_len
990990

991991

992+
def print_non_trivial_mesh_axis(mesh):
  """Print every mesh axis whose size is greater than one.

  Args:
    mesh: A mesh object whose ``shape`` maps axis names to axis sizes
      (e.g. a ``jax.sharding.Mesh``).
  """
  # Axes of size 1 carry no sharding information, so only report the rest.
  nontrivial = ((name, size) for name, size in mesh.shape.items() if size > 1)
  for name, size in nontrivial:
    print(f"{name}: {size}", flush=True)
997+
998+
992999
@contextmanager
9931000
def maybe_get_transformer_engine_context(config):
9941001
"""Runs a transformer engine context engine manager for GPUs only."""

src/MaxText/maxtext_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,3 +1142,14 @@ def schedule(step):
11421142
boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
11431143

11441144
return optax.join_schedules(pieces, boundaries)
1145+
1146+
1147+
def print_state_mesh_shardings_params(state, state_sharding, mesh):
  """Print one line per parameter leaf: its path, abstract type, and sharding.

  Args:
    state: Train state whose ``params`` pytree holds the model weights.
    state_sharding: Matching pytree of shardings for ``state.params``.
    mesh: The device mesh used to drop size-one axes from each spec.
  """
  param_leaves, _ = jax.tree_util.tree_flatten_with_path(state.params)
  sharding_leaves, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
  # Both pytrees share a structure, so flattening yields aligned leaf orders.
  for (key_path, value), (_, leaf_sharding) in zip(param_leaves, sharding_leaves):
    # NOTE(review): assumes every path entry is a DictKey exposing ``.key``
    # (true for flax param dicts) — attribute-style entries would not have it.
    name = "/".join(str(entry.key) for entry in key_path)
    # Hide trivial (size-one) mesh axes so the printed spec stays readable.
    pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
    print(f"{name:.<80} {jax.typeof(value)} {pspec}", flush=True)

src/MaxText/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
240240
_loss_fn = dpo_loss_fn
241241

242242
params = state.params
243-
244243
if config.gradient_accumulation_steps > 1:
245244
loss, aux, raw_grads = gradient_accumulation_loss_and_grad(
246245
_loss_fn,
@@ -411,6 +410,11 @@ def train_loop(config, recorder, state=None):
411410
compiled_stats = compiled.memory_analysis()
412411
max_utils.print_compiled_memory_stats(compiled_stats)
413412

413+
# Print weight sharding info when debug_sharding is enabled
414+
if config.debug_sharding:
415+
max_utils.print_non_trivial_mesh_axis(mesh)
416+
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, mesh)
417+
414418
start_step = get_first_step(state) # this is the start_step for training
415419
prof = profiler.Profiler(config, offset_step=start_step)
416420
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

src/MaxText/train_compile.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import jax
3131
from jax.experimental.topologies import get_topology_desc
32-
from jax.sharding import Mesh
32+
from jax.sharding import Mesh, AxisType
3333
from jax.experimental.serialize_executable import serialize
3434

3535
from flax.linen import partitioning as nn_partitioning
@@ -41,7 +41,7 @@
4141
from MaxText import max_utils
4242
from MaxText import pyconfig
4343
from MaxText import sharding
44-
from MaxText.common_types import MODEL_MODE_TRAIN
44+
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
4545
from MaxText.layers import models
4646
from MaxText.layers import quantizations
4747
from MaxText.utils import gcs_utils
@@ -77,8 +77,11 @@ def get_topology_mesh(config):
7777
num_slices=config.compile_topology_num_slices,
7878
wrap=target_hardware.wrap,
7979
).devices
80+
if config.shard_mode == ShardMode.EXPLICIT:
81+
jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
8082
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
81-
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes)
83+
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
84+
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
8285
return topology_mesh
8386

8487

@@ -236,6 +239,11 @@ def main(argv: Sequence[str]) -> None:
236239
)
237240
print("Jitting and compilation complete!", flush=True)
238241

242+
# Print weight sharding info when debug_sharding is enabled
243+
if config.debug_sharding:
244+
max_utils.print_non_trivial_mesh_axis(topology_mesh)
245+
maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh)
246+
239247
# Serialize and save the compiled object
240248
if config.compiled_trainstep_file != "":
241249
print("Saving compiled object...")

0 commit comments

Comments
 (0)