Skip to content

Commit 79afcb8

Browse files
committed
add sharding debug feature
1 parent bfa20a4 commit 79afcb8

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/MaxText/train_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def setup_train_loop(config, recorder, devices=None):
219219
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
220220
sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
221221

222+
# print weights sharding info under debug sharding mode
223+
if config.debug_sharding:
224+
max_utils.print_non_trivial_mesh_axis(model.mesh)
225+
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
226+
222227
if config.use_dpo:
223228
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
224229
max_logging.log(

0 commit comments

Comments
 (0)