Skip to content

Commit fa814f7

Browse files
committed
add sharding debug feature
1 parent c2574ab commit fa814f7

5 files changed

Lines changed: 38 additions & 4 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ enable_jax_profiler: False
832832
jax_profiler_port: 9999
833833

834834
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
835+
debug_sharding: False # Prints model weights sharding info under explicit shard mode
835836

836837
# Checkpoint Structured logging
837838
enable_checkpoint_cloud_logger: False

src/MaxText/configs/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ class RunInfo(BaseModel):
243243
True,
244244
description="If True, prints the final configuration after initialization.",
245245
)
246+
debug_sharding: bool = Field(
247+
False, description="If True, print model weight sharding details when using explicit shard mode."
248+
)
246249
base_output_directory: PathStr = Field("", description="Base directory for all outputs, typically a GCS path.")
247250
sharding_strategy: None | Literal["experimental"] = Field(
248251
None,
@@ -2052,6 +2055,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
20522055
)
20532056
if self.quantization:
20542057
raise ValueError("Quantization is not supported with 'explicit' sharding.")
2058+
else:
2059+
if self.debug_sharding:
2060+
raise ValueError("Debug sharding function only works with explicit shard mode. Please set shard_mode=explicit.")
20552061
if (
20562062
self.per_device_batch_size > 0
20572063
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0

src/MaxText/max_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,21 @@ def get_batch_seq_len_for_mode(config, model_mode):
989989
return batch_size, seq_len
990990

991991

992+
def print_non_trivial_mesh_axis(mesh):
  """Print every mesh axis whose size exceeds one, one `name: size` line each.

  Axes of size 1 carry no sharding information, so they are skipped to keep
  the debug output focused on the axes that actually partition the model.

  Args:
    mesh: a mesh object exposing `.shape` as a mapping of axis name -> size
      (e.g. `jax.sharding.Mesh`).
  """
  for mesh_axis, axis_size in mesh.shape.items():
    # Guard clause: trivial (size-1) axes are uninteresting for debugging.
    if axis_size <= 1:
      continue
    print(f"{mesh_axis}: {axis_size}", flush=True)
997+
998+
999+
def print_params_sharding_shapes(tree):
1000+
"""Print flatten tree and their leaf shapes."""
1001+
leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
1002+
for path, leaf in leaves:
1003+
path_str = "/".join(str(p.key) for p in path)
1004+
print(f"{path_str:.<80} {jax.typeof(leaf)}", flush=True)
1005+
1006+
9921007
@contextmanager
9931008
def maybe_get_transformer_engine_context(config):
9941009
"""Runs a transformer engine context engine manager for GPUs only."""

src/MaxText/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
240240
_loss_fn = dpo_loss_fn
241241

242242
params = state.params
243-
244243
if config.gradient_accumulation_steps > 1:
245244
loss, aux, raw_grads = gradient_accumulation_loss_and_grad(
246245
_loss_fn,
@@ -411,6 +410,11 @@ def train_loop(config, recorder, state=None):
411410
compiled_stats = compiled.memory_analysis()
412411
max_utils.print_compiled_memory_stats(compiled_stats)
413412

413+
# print weights sharding info under debug sharding mode
414+
if config.debug_sharding:
415+
max_utils.print_non_trivial_mesh_axis(mesh)
416+
max_utils.print_params_sharding_shapes(state.params)
417+
414418
start_step = get_first_step(state) # this is the start_step for training
415419
prof = profiler.Profiler(config, offset_step=start_step)
416420
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

src/MaxText/train_compile.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import jax
3131
from jax.experimental.topologies import get_topology_desc
32-
from jax.sharding import Mesh
32+
from jax.sharding import Mesh, AxisType
3333
from jax.experimental.serialize_executable import serialize
3434

3535
from flax.linen import partitioning as nn_partitioning
@@ -41,7 +41,7 @@
4141
from MaxText import max_utils
4242
from MaxText import pyconfig
4343
from MaxText import sharding
44-
from MaxText.common_types import MODEL_MODE_TRAIN
44+
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
4545
from MaxText.layers import models
4646
from MaxText.layers import quantizations
4747
from MaxText.utils import gcs_utils
@@ -77,8 +77,11 @@ def get_topology_mesh(config):
7777
num_slices=config.compile_topology_num_slices,
7878
wrap=target_hardware.wrap,
7979
).devices
80+
if config.shard_mode == ShardMode.EXPLICIT:
81+
jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
8082
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
81-
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes)
83+
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
84+
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
8285
return topology_mesh
8386

8487

@@ -236,6 +239,11 @@ def main(argv: Sequence[str]) -> None:
236239
)
237240
print("Jitting and compilation complete!", flush=True)
238241

242+
# print weights sharding info under debug sharding mode
243+
if config.debug_sharding:
244+
max_utils.print_non_trivial_mesh_axis(topology_mesh)
245+
max_utils.print_params_sharding_shapes(shaped_train_args[0].params)
246+
239247
# Serialize and save the compiled object
240248
if config.compiled_trainstep_file != "":
241249
print("Saving compiled object...")

0 commit comments

Comments
 (0)