Skip to content

Commit 36029ce

Browse files
committed
add sharding debug feature
1 parent bfa20a4 commit 36029ce

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/MaxText/train_utils.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -109,11 +109,6 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr
109109
donate_argnums=donate_argnums,
110110
)
111111

112-
# print weights sharding info under debug sharding mode
113-
if config.debug_sharding:
114-
max_utils.print_non_trivial_mesh_axis(model.mesh)
115-
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
116-
117112
return p_train_step
118113

119114

@@ -219,6 +214,11 @@ def setup_train_loop(config, recorder, devices=None):
219214
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
220215
sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
221216

217+
# print weights sharding info under debug sharding mode
218+
if config.debug_sharding:
219+
max_utils.print_non_trivial_mesh_axis(model.mesh)
220+
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
221+
222222
if config.use_dpo:
223223
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
224224
max_logging.log(

0 commit comments

Comments
 (0)