|
29 | 29 |
|
30 | 30 | import jax |
31 | 31 | from jax.experimental.topologies import get_topology_desc |
32 | | -from jax.sharding import Mesh |
| 32 | +from jax.sharding import Mesh, AxisType |
33 | 33 | from jax.experimental.serialize_executable import serialize |
34 | 34 |
|
35 | 35 | from flax.linen import partitioning as nn_partitioning |
|
41 | 41 | from MaxText import max_utils |
42 | 42 | from MaxText import pyconfig |
43 | 43 | from MaxText import sharding |
44 | | -from MaxText.common_types import MODEL_MODE_TRAIN |
| 44 | +from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode |
45 | 45 | from MaxText.layers import models |
46 | 46 | from MaxText.layers import quantizations |
47 | 47 | from MaxText.utils import gcs_utils |
@@ -77,8 +77,11 @@ def get_topology_mesh(config): |
77 | 77 | num_slices=config.compile_topology_num_slices, |
78 | 78 | wrap=target_hardware.wrap, |
79 | 79 | ).devices |
| 80 | + if config.shard_mode == ShardMode.EXPLICIT: |
| 81 | + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) |
80 | 82 | topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) |
81 | | - topology_mesh = Mesh(topology_device_mesh, config.mesh_axes) |
| 83 | + mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto |
| 84 | + topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes)) |
82 | 85 | return topology_mesh |
83 | 86 |
|
84 | 87 |
|
@@ -221,6 +224,11 @@ def main(argv: Sequence[str]) -> None: |
221 | 224 | ) |
222 | 225 | ) |
223 | 226 |
|
| 227 | + # When config.debug_sharding is set, print the mesh-axis and parameter sharding info before compiling |
| 228 | + if config.debug_sharding: |
| 229 | + max_utils.print_non_trivial_mesh_axis(topology_mesh) |
| 230 | + maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh) |
| 231 | + |
224 | 232 | # Compile |
225 | 233 | print("Jitting and compiling train step...", flush=True) |
226 | 234 | compiled = jit_and_compile( |
|
0 commit comments